diff --git a/docs/api/treatment_effect.md b/docs/api/treatment_effect.md index 56af0814..a8c47818 100644 --- a/docs/api/treatment_effect.md +++ b/docs/api/treatment_effect.md @@ -7,9 +7,9 @@ The `TreatmentEffect` module is designed for analyzing treatment effects within :toctree: _autosummary :recursive: - medmodels.treatment_effect.treatment_effect.TreatmentEffect - medmodels.treatment_effect.builder.TreatmentEffectBuilder - medmodels.treatment_effect.estimate.Estimate - medmodels.treatment_effect.report.Report + medmodels.treatment_effect.treatment_effect + medmodels.treatment_effect.builder + medmodels.treatment_effect.estimate + medmodels.treatment_effect.report medmodels.treatment_effect.temporal_analysis ``` diff --git a/docs/conf.py b/docs/conf.py index 89ee37a0..c57d42c3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -69,6 +69,7 @@ "private-members": False, "inherited-members": True, "show-inheritance": True, + "ignore-module-all": False, } autosummary_generate = True @@ -170,7 +171,7 @@ # Local Sphinx extensions -def setup(app: Sphinx): +def setup(app: Sphinx) -> None: """Add custom directives and transformations to Sphinx.""" from myst_parser._docs import ( DirectiveDoc, diff --git a/docs/developer_guide/docstrings.md b/docs/developer_guide/docstrings.md index 5668959c..2874703d 100644 --- a/docs/developer_guide/docstrings.md +++ b/docs/developer_guide/docstrings.md @@ -29,21 +29,14 @@ The summary line provides a brief description of what the function or class does Example: -```python -def example_function(param1, param2): - """This function performs an example operation. - - Args: - param1 (int): The first parameter. - param2 (str): The second parameter. - - Returns: - bool: True if successful, False otherwise. - """ +```{literalinclude} example_docstrings.py +:lines: 1-5, 6-38 +:emphasize-lines: 13 ``` -:::{note} -The summary line cannot contain line breaks. When writing docstrings, the max line length setting can be exceeded without triggering linting errors. +:::{admonition} No Linebreaks In Summary Line +:type: note +The summary line cannot contain line breaks and must not exceed the maximum line length. Ensure that the summary line provides a brief description of the function's purpose. Further explanation and details can be given in the [Description](#description) text block underneath. ::: ### Description @@ -221,24 +214,24 @@ Document parameters under the `Args` section. Each parameter should include its Example: -```python -def example_function(param1, param2, param3): - """This function performs an example operation. - - Args: - param1 (int): The first parameter. - param2 (Union[str, List[str], Dict[str, Any]]): The second parameter, which can be - a string, a list of strings, or a dictionary with string keys and any type of - values. This parameter is used to demonstrate a long type definition. - param3 (str): The third parameter, which has a long description that needs to be - broken into multiple lines for better readability. This parameter is used - to show how to write long descriptions and how to indent them properly in - the docstring. - """ +```{literalinclude} example_docstrings.py +:lines: 1-5, 6-38 +:emphasize-lines: 18-26 ``` -:::{note} -Type definitions of parameters cannot have a line break. You can not put a line break inside `(Union[str, List[str], Dict[str, Any]])`. When writing docstrings, the max line length setting can be exceeded without triggering linting errors. +:::{admonition} No Linebreak in Type Definitions +:type: note + +When writing type definitions in argument docstrings, avoid placing line breaks inside the type annotations. For instance, complex types like `(Union[str, List[str], Dict[str, Any]])` should appear on a single line without splitting across lines. This ensures the type definition remains clear and avoids parsing issues. + +If a type definition exceeds the maximum line length limit, deactivate the line length rule for that doctstring using [`# noqa: E501`](https://docs.astral.sh/ruff/rules/line-too-long/#error-suppression) after the closing `"""`. This keeps the type annotation intact while preventing the linter from flagging the long line as an error. +::: + +:::{admonition} No Boolean Argument Types +:type: note +Avoid using booleans as function arguments due to the "boolean trap". Booleans reduce code clarity, as `True` or `False` doesn't explain meaning. They also limit future flexibility — if more than two options are needed, changing the argument type can cause breaking changes. Instead, use an enum or a more descriptive type for better clarity and flexibility. + +For more information, you can refer to [Adam Johnson's article](https://adamj.eu/tech/2021/07/10/python-type-hints-how-to-avoid-the-boolean-trap/) which discusses this in detail and provides examples of how to avoid the boolean trap. ::: ### Return Types @@ -256,6 +249,11 @@ def example_function(param1, param2): """ ``` +:::{admonition} Don't Document `None` Return Value +:type: note +Don't add a `Returns` section when the function has no return value (`def fun() -> None`) +::: + ### Examples Provide examples under the `Examples` section by using Sphinx's `.. code-block::` directive within the docstrings. @@ -286,8 +284,6 @@ The second block shows the return value when executing the code. The output valu ``` - - Full Example: ```python diff --git a/docs/developer_guide/example_docstrings.py b/docs/developer_guide/example_docstrings.py new file mode 100644 index 00000000..8c0861f1 --- /dev/null +++ b/docs/developer_guide/example_docstrings.py @@ -0,0 +1,60 @@ +"""Example module with docstrings for the developer guide.""" + +from __future__ import annotations + +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + + +def example_function_args( + param1: int, + param2: Union[str, int], + optional_param: Optional[List[str]] = None, + *args: Union[float, str], + **kwargs: Dict[str, Any], +) -> Tuple[bool, List[str]]: + """Example function with PEP 484 type annotations and PEP 563 future annotations. + + This function shows how to define and document typing for different kinds of + arguments, including positional, optional, variable-length args, and keyword args. + + Args: + param1 (int): A required integer parameter. + param2 (Union[str, int]): A parameter that can be either a string or an integer. + optional_param (Optional[List[str]], optional): An optional parameter that + accepts a list of strings. Defaults to None if not provided. + *args (Union[float, str]): Variable length argument list that accepts floats or + strings. + **kwargs (Dict[str, Any]): Arbitrary keyword arguments as a dictionary of string + keys and values of any type. + + Returns: + Tuple[bool, List[str]]: A tuple containing: + - bool: Always True for this example. + - List[str]: A list with a single string describing the received arguments. + """ + result = ( + f"Received: param1={param1}, param2={param2}, optional_param={optional_param}, " + f"args={args}, kwargs={kwargs}" + ) + + return True, [result] + + +def example_generator(n: int) -> Iterator[int]: + """Generators have a ``Yields`` section instead of a ``Returns`` section. + + Args: + n (int): The upper limit of the range to generate, from 0 to `n` - 1. + + Yields: + int: The next number in the range of 0 to `n` - 1. + + Examples: + Examples should be written in doctest format, and should illustrate how + to use the function. + + >>> print([i for i in example_generator(4)]) + [0, 1, 2, 3] + + """ + yield from range(n) diff --git a/docs/serve_docs.py b/docs/serve_docs.py index c2c2b7ce..b5952a02 100644 --- a/docs/serve_docs.py +++ b/docs/serve_docs.py @@ -1,15 +1,17 @@ #!/usr/bin/env python3 +import logging import os import subprocess from pathlib import Path from livereload import Server, shell +logging.basicConfig(level=logging.INFO) -def setup_live_docs_server(): - """ - Set up and run a live documentation server. + +def setup_live_docs_server() -> None: + """Set up and run a live documentation server. This function initializes a server that automatically rebuilds and refreshes the documentation when source files are modified. @@ -25,7 +27,7 @@ def setup_live_docs_server(): rebuild_cmd = shell("make html", cwd=str(script_path)) # Initially build the documentation - print("Building initial documentation...") + logging.info("Building initial documentation...") subprocess.run(["make", "html"], cwd=str(script_path), check=True) # File patterns to watch for changes @@ -48,7 +50,7 @@ def setup_live_docs_server(): if __name__ == "__main__": - print("Starting live documentation server...") - print("Access the docs at http://localhost:5500") - print("Press Ctrl+C to stop the server.") + logging.info("Starting live documentation server...") + logging.info("Access the docs at http://localhost:5500") + logging.info("Press Ctrl+C to stop the server.") setup_live_docs_server() diff --git a/medmodels/_medmodels.pyi b/medmodels/_medmodels.pyi index 4762b4a5..543c4582 100644 --- a/medmodels/_medmodels.pyi +++ b/medmodels/_medmodels.pyi @@ -1,5 +1,3 @@ -from __future__ import annotations - from enum import Enum from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union diff --git a/medmodels/medrecord/__init__.py b/medmodels/medrecord/__init__.py index 3a9f3f6f..01913f8f 100644 --- a/medmodels/medrecord/__init__.py +++ b/medmodels/medrecord/__init__.py @@ -20,23 +20,23 @@ from medmodels.medrecord.schema import AttributeType, GroupSchema, Schema __all__ = [ - "MedRecord", - "String", - "Int", - "Float", + "Any", + "AttributeType", "Bool", "DateTime", + "EdgeIndex", + "EdgeOperation", + "Float", + "GroupSchema", + "Int", + "MedRecord", + "NodeIndex", + "NodeOperation", "Null", - "Any", - "Union", "Option", - "AttributeType", "Schema", - "GroupSchema", - "node", + "String", + "Union", "edge", - "NodeIndex", - "EdgeIndex", - "NodeOperation", - "EdgeOperation", + "node", ] diff --git a/medmodels/medrecord/_overview.py b/medmodels/medrecord/_overview.py index e22a16a1..f4876461 100644 --- a/medmodels/medrecord/_overview.py +++ b/medmodels/medrecord/_overview.py @@ -184,7 +184,7 @@ def prettify_table( info_order = ["min", "max", "mean", "values"] - for group in data.keys(): + for group in data: # determine longest group name and count lengths[0] = max(len(str(group)), lengths[0]) diff --git a/medmodels/medrecord/builder.py b/medmodels/medrecord/builder.py index ab014de1..8c1eaaa4 100644 --- a/medmodels/medrecord/builder.py +++ b/medmodels/medrecord/builder.py @@ -5,8 +5,9 @@ if TYPE_CHECKING: from typing_extensions import TypeIs + from medmodels.medrecord.schema import Schema + import medmodels as mm -from medmodels.medrecord.schema import Schema from medmodels.medrecord.types import ( EdgeTuple, Group, @@ -157,7 +158,7 @@ def add_edges( return self def add_group( - self, group: Group, *, nodes: List[NodeIndex] = [] + self, group: Group, *, nodes: Optional[List[NodeIndex]] = None ) -> MedRecordBuilder: """Adds a group to the builder with an optional list of nodes. @@ -168,6 +169,8 @@ def add_group( Returns: MedRecordBuilder: The current instance of the builder. """ + if nodes is None: + nodes = [] self._groups[group] = {"nodes": nodes, "edges": []} return self diff --git a/medmodels/medrecord/datatype.py b/medmodels/medrecord/datatype.py index 2dcfdea3..daa8f3b2 100644 --- a/medmodels/medrecord/datatype.py +++ b/medmodels/medrecord/datatype.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from medmodels._medmodels import ( PyAny, @@ -36,7 +36,7 @@ ] -class DataType(metaclass=ABCMeta): +class DataType(ABC): @abstractmethod def _inner(self) -> PyDataType: ... @@ -53,25 +53,24 @@ def __eq__(self, value: object) -> bool: ... def _from_pydatatype(datatype: PyDataType) -> DataType: if isinstance(datatype, PyString): return String() - elif isinstance(datatype, PyInt): + if isinstance(datatype, PyInt): return Int() - elif isinstance(datatype, PyFloat): + if isinstance(datatype, PyFloat): return Float() - elif isinstance(datatype, PyBool): + if isinstance(datatype, PyBool): return Bool() - elif isinstance(datatype, PyDateTime): + if isinstance(datatype, PyDateTime): return DateTime() - elif isinstance(datatype, PyNull): + if isinstance(datatype, PyNull): return Null() - elif isinstance(datatype, PyAny): + if isinstance(datatype, PyAny): return Any() - elif isinstance(datatype, PyUnion): + if isinstance(datatype, PyUnion): return Union( DataType._from_pydatatype(datatype.dtype1), DataType._from_pydatatype(datatype.dtype2), ) - else: - return Option(DataType._from_pydatatype(datatype.dtype)) + return Option(DataType._from_pydatatype(datatype.dtype)) class String(DataType): @@ -212,8 +211,9 @@ class Union(DataType): def __init__(self, *dtypes: DataType) -> None: if len(dtypes) < 2: - raise ValueError("Union must have at least two arguments") - elif len(dtypes) == 2: + msg = "Union must have at least two arguments" + raise ValueError(msg) + if len(dtypes) == 2: self._union = PyUnion(dtypes[0]._inner(), dtypes[1]._inner()) else: self._union = PyUnion(dtypes[0]._inner(), Union(*dtypes[1:])._inner()) diff --git a/medmodels/medrecord/indexers.py b/medmodels/medrecord/indexers.py index b76404e1..439e1335 100644 --- a/medmodels/medrecord/indexers.py +++ b/medmodels/medrecord/indexers.py @@ -92,7 +92,8 @@ def __getitem__( if isinstance(key, slice): if key.start is not None or key.stop is not None or key.step is not None: - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node(self._medrecord.nodes) @@ -110,7 +111,7 @@ def __getitem__( ): attributes = self._medrecord._medrecord.node(index_selection) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if isinstance(index_selection, NodeOperation) and is_medrecord_attribute( attribute_selection @@ -119,7 +120,7 @@ def __getitem__( self._medrecord.select_nodes(index_selection) ) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if isinstance(index_selection, slice) and is_medrecord_attribute( attribute_selection @@ -129,11 +130,12 @@ def __getitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) attributes = self._medrecord._medrecord.node(self._medrecord.nodes) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if is_node_index(index_selection) and isinstance(attribute_selection, list): return { @@ -148,7 +150,7 @@ def __getitem__( return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if isinstance(index_selection, NodeOperation) and isinstance( @@ -160,7 +162,7 @@ def __getitem__( return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if isinstance(index_selection, slice) and isinstance(attribute_selection, list): @@ -169,13 +171,14 @@ def __getitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) attributes = self._medrecord._medrecord.node(self._medrecord.nodes) return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if is_node_index(index_selection) and isinstance(attribute_selection, slice): @@ -184,7 +187,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node([index_selection])[index_selection] @@ -194,7 +198,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node(index_selection) @@ -206,7 +211,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node( self._medrecord.select_nodes(index_selection) @@ -223,9 +229,11 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node(self._medrecord.nodes) + return None @overload def __setitem__( @@ -260,19 +268,22 @@ def __setitem__( ) -> None: if is_node_index(key): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes([key], value) if isinstance(key, list): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes(key, value) if isinstance(key, NodeOperation): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( self._medrecord.select_nodes(key), value @@ -280,10 +291,12 @@ def __setitem__( if isinstance(key, slice): if key.start is not None or key.stop is not None or key.step is not None: - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( self._medrecord.nodes, value @@ -295,7 +308,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_node_attribute( [index_selection], attribute_selection, value @@ -305,7 +319,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_node_attribute( index_selection, attribute_selection, value @@ -315,7 +330,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_node_attribute( self._medrecord.select_nodes(index_selection), @@ -331,10 +347,12 @@ def __setitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_node_attribute( self._medrecord.nodes, @@ -344,38 +362,41 @@ def __setitem__( if is_node_index(index_selection) and isinstance(attribute_selection, list): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_node_attribute( [index_selection], attribute, value ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, list): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_node_attribute( index_selection, attribute, value ) - return + return None if isinstance(index_selection, NodeOperation) and isinstance( attribute_selection, list ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_node_attribute( self._medrecord.select_nodes(index_selection), attribute, value ) - return + return None if isinstance(index_selection, slice) and isinstance(attribute_selection, list): if ( @@ -383,17 +404,19 @@ def __setitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_node_attribute( self._medrecord.nodes, attribute, value ) - return + return None if is_node_index(index_selection) and isinstance(attribute_selection, slice): if ( @@ -401,23 +424,25 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.node([index_selection])[ index_selection ] - for attribute in attributes.keys(): + for attribute in attributes: self._medrecord._medrecord.update_node_attribute( [index_selection], attribute, value, ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, slice): if ( @@ -425,20 +450,22 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.node(index_selection) - for node in attributes.keys(): - for attribute in attributes[node].keys(): + for node in attributes: + for attribute in attributes[node]: self._medrecord._medrecord.update_node_attribute( [node], attribute, value ) - return + return None if isinstance(index_selection, NodeOperation) and isinstance( attribute_selection, slice @@ -448,22 +475,24 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.node( self._medrecord.select_nodes(index_selection) ) - for node in attributes.keys(): - for attribute in attributes[node].keys(): + for node in attributes: + for attribute in attributes[node]: self._medrecord._medrecord.update_node_attribute( [node], attribute, value ) - return + return None if isinstance(index_selection, slice) and isinstance( attribute_selection, slice @@ -476,20 +505,23 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.node(self._medrecord.nodes) - for node in attributes.keys(): - for attribute in attributes[node].keys(): + for node in attributes: + for attribute in attributes[node]: self._medrecord._medrecord.update_node_attribute( [node], attribute, value ) - return + return None + return None def __delitem__( self, @@ -530,7 +562,8 @@ def __delitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.remove_node_attribute( self._medrecord.nodes, @@ -543,7 +576,7 @@ def __delitem__( [index_selection], attribute ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, list): for attribute in attribute_selection: @@ -551,7 +584,7 @@ def __delitem__( index_selection, attribute ) - return + return None if isinstance(index_selection, NodeOperation) and isinstance( attribute_selection, list @@ -561,7 +594,7 @@ def __delitem__( self._medrecord.select_nodes(index_selection), attribute ) - return + return None if isinstance(index_selection, slice) and isinstance(attribute_selection, list): if ( @@ -569,14 +602,15 @@ def __delitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.remove_node_attribute( self._medrecord.nodes, attribute ) - return + return None if is_node_index(index_selection) and isinstance(attribute_selection, slice): if ( @@ -584,7 +618,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( [index_selection], {} @@ -596,7 +631,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( index_selection, {} @@ -610,7 +646,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( self._medrecord.select_nodes(index_selection), {} @@ -627,11 +664,13 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( self._medrecord.nodes, {} ) + return None class EdgeIndexer: @@ -702,7 +741,8 @@ def __getitem__( if isinstance(key, slice): if key.start is not None or key.stop is not None or key.step is not None: - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge(self._medrecord.edges) @@ -720,7 +760,7 @@ def __getitem__( ): attributes = self._medrecord._medrecord.edge(index_selection) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute( attribute_selection @@ -729,7 +769,7 @@ def __getitem__( self._medrecord.select_edges(index_selection) ) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if isinstance(index_selection, slice) and is_medrecord_attribute( attribute_selection @@ -739,11 +779,12 @@ def __getitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge(self._medrecord.edges) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if is_edge_index(index_selection) and isinstance(attribute_selection, list): return { @@ -758,7 +799,7 @@ def __getitem__( return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if isinstance(index_selection, EdgeOperation) and isinstance( @@ -770,7 +811,7 @@ def __getitem__( return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if isinstance(index_selection, slice) and isinstance(attribute_selection, list): @@ -779,13 +820,14 @@ def __getitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge(self._medrecord.edges) return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if is_edge_index(index_selection) and isinstance(attribute_selection, slice): @@ -794,7 +836,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge([index_selection])[index_selection] @@ -804,7 +847,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge(index_selection) @@ -816,7 +860,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge( self._medrecord.select_edges(index_selection) @@ -833,9 +878,11 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge(self._medrecord.edges) + return None @overload def __setitem__( @@ -870,19 +917,22 @@ def __setitem__( ) -> None: if is_edge_index(key): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes([key], value) if isinstance(key, list): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes(key, value) if isinstance(key, EdgeOperation): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( self._medrecord.select_edges(key), value @@ -890,10 +940,12 @@ def __setitem__( if isinstance(key, slice): if key.start is not None or key.stop is not None or key.step is not None: - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( self._medrecord.edges, value @@ -905,7 +957,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_edge_attribute( [index_selection], attribute_selection, value @@ -915,7 +968,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_edge_attribute( index_selection, attribute_selection, value @@ -925,7 +979,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_edge_attribute( self._medrecord.select_edges(index_selection), @@ -941,10 +996,12 @@ def __setitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_edge_attribute( self._medrecord.edges, @@ -954,38 +1011,41 @@ def __setitem__( if is_edge_index(index_selection) and isinstance(attribute_selection, list): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_edge_attribute( [index_selection], attribute, value ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, list): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_edge_attribute( index_selection, attribute, value ) - return + return None if isinstance(index_selection, EdgeOperation) and isinstance( attribute_selection, list ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_edge_attribute( self._medrecord.select_edges(index_selection), attribute, value ) - return + return None if isinstance(index_selection, slice) and isinstance(attribute_selection, list): if ( @@ -993,17 +1053,19 @@ def __setitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_edge_attribute( self._medrecord.edges, attribute, value ) - return + return None if is_edge_index(index_selection) and isinstance(attribute_selection, slice): if ( @@ -1011,21 +1073,23 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge([index_selection])[ index_selection ] - for attribute in attributes.keys(): + for attribute in attributes: self._medrecord._medrecord.update_edge_attribute( [index_selection], attribute, value ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, slice): if ( @@ -1033,20 +1097,22 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge(index_selection) - for edge in attributes.keys(): - for attribute in attributes[edge].keys(): + for edge in attributes: + for attribute in attributes[edge]: self._medrecord._medrecord.update_edge_attribute( [edge], attribute, value ) - return + return None if isinstance(index_selection, EdgeOperation) and isinstance( attribute_selection, slice @@ -1056,22 +1122,24 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge( self._medrecord.select_edges(index_selection) ) - for edge in attributes.keys(): - for attribute in attributes[edge].keys(): + for edge in attributes: + for attribute in attributes[edge]: self._medrecord._medrecord.update_edge_attribute( [edge], attribute, value ) - return + return None if isinstance(index_selection, slice) and isinstance( attribute_selection, slice @@ -1084,20 +1152,23 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge(self._medrecord.edges) - for edge in attributes.keys(): - for attribute in attributes[edge].keys(): + for edge in attributes: + for attribute in attributes[edge]: self._medrecord._medrecord.update_edge_attribute( [edge], attribute, value ) - return + return None + return None def __delitem__( self, @@ -1138,7 +1209,8 @@ def __delitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.remove_edge_attribute( self._medrecord.edges, @@ -1151,7 +1223,7 @@ def __delitem__( [index_selection], attribute ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, list): for attribute in attribute_selection: @@ -1159,7 +1231,7 @@ def __delitem__( index_selection, attribute ) - return + return None if isinstance(index_selection, EdgeOperation) and isinstance( attribute_selection, list @@ -1169,7 +1241,7 @@ def __delitem__( self._medrecord.select_edges(index_selection), attribute ) - return + return None if isinstance(index_selection, slice) and isinstance(attribute_selection, list): if ( @@ -1177,14 +1249,15 @@ def __delitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.remove_edge_attribute( self._medrecord.edges, attribute ) - return + return None if is_edge_index(index_selection) and isinstance(attribute_selection, slice): if ( @@ -1192,7 +1265,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( [index_selection], {} @@ -1204,7 +1278,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( index_selection, {} @@ -1218,7 +1293,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( self._medrecord.select_edges(index_selection), {} @@ -1235,8 +1311,10 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( self._medrecord.edges, {} ) + return None diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index 69a5bcae..72205d96 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -94,7 +94,6 @@ def __init__( group_header (str): Header for group column, i.e. 'Group Nodes'. decimal (int): Decimal point to round the float values to. """ - self.data = data self.group_header = group_header self.decimal = decimal @@ -563,11 +562,10 @@ def edges_connecting( (source_node if isinstance(source_node, list) else [source_node]), (target_node if isinstance(target_node, list) else [target_node]), ) - else: - return self._medrecord.edges_connecting_undirected( - (source_node if isinstance(source_node, list) else [source_node]), - (target_node if isinstance(target_node, list) else [target_node]), - ) + return self._medrecord.edges_connecting_undirected( + (source_node if isinstance(source_node, list) else [source_node]), + (target_node if isinstance(target_node, list) else [target_node]), + ) @overload def remove_nodes(self, nodes: NodeIndex) -> Attributes: ... @@ -778,25 +776,24 @@ def add_edges( edges ): return self.add_edges_pandas(edges, group) - elif is_polars_edge_dataframe_input( + if is_polars_edge_dataframe_input(edges) or is_polars_edge_dataframe_input_list( edges - ) or is_polars_edge_dataframe_input_list(edges): + ): return self.add_edges_polars(edges, group) - else: - if is_edge_tuple(edges): - edges = [edges] + if is_edge_tuple(edges): + edges = [edges] - edge_indices = self._medrecord.add_edges(edges) + edge_indices = self._medrecord.add_edges(edges) - if group is None: - return edge_indices + if group is None: + return edge_indices - if not self.contains_group(group): - self.add_group(group) + if not self.contains_group(group): + self.add_group(group) - self.add_edges_to_group(group, edge_indices) + self.add_edges_to_group(group, edge_indices) - return edge_indices + return edge_indices def add_edges_pandas( self, @@ -896,16 +893,15 @@ def add_group( nodes if isinstance(nodes, list) else [nodes], edges if isinstance(edges, list) else [edges], ) - elif nodes is not None: + if nodes is not None: return self._medrecord.add_group( group, nodes if isinstance(nodes, list) else [nodes], None ) - elif edges is not None: + if edges is not None: return self._medrecord.add_group( group, None, edges if isinstance(edges, list) else [edges] ) - else: - return self._medrecord.add_group(group, None, None) + return self._medrecord.add_group(group, None, None) def remove_groups(self, groups: Union[Group, GroupInputList]) -> None: """Removes one or more groups from the MedRecord instance. diff --git a/medmodels/medrecord/querying.py b/medmodels/medrecord/querying.py index 5a6e8385..4a18d80e 100644 --- a/medmodels/medrecord/querying.py +++ b/medmodels/medrecord/querying.py @@ -33,7 +33,7 @@ class NodeOperation: _node_operation: PyNodeOperation - def __init__(self, node_operation: PyNodeOperation): + def __init__(self, node_operation: PyNodeOperation) -> None: self._node_operation = node_operation def logical_and(self, operation: NodeOperation) -> NodeOperation: @@ -1360,12 +1360,9 @@ def has_neighbor_with( return NodeOperation( self._node_operand.has_neighbor_with(operation._node_operation) ) - else: - return NodeOperation( - self._node_operand.has_neighbor_undirected_with( - operation._node_operation - ) - ) + return NodeOperation( + self._node_operand.has_neighbor_undirected_with(operation._node_operation) + ) def attribute(self, attribute: MedRecordAttribute) -> NodeAttributeOperand: """Accesses an NodeAttributeOperand for the specified attribute, allowing for the creation of operations based on node attributes. diff --git a/medmodels/medrecord/schema.py b/medmodels/medrecord/schema.py index 2dd28890..5b0cd5fc 100644 --- a/medmodels/medrecord/schema.py +++ b/medmodels/medrecord/schema.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum, auto -from typing import Dict, List, Optional, Tuple, Union, overload +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, overload from medmodels._medmodels import ( PyAttributeDataType, @@ -10,7 +10,9 @@ PySchema, ) from medmodels.medrecord.datatype import DataType -from medmodels.medrecord.types import Group, MedRecordAttribute + +if TYPE_CHECKING: + from medmodels.medrecord.types import Group, MedRecordAttribute class AttributeType(Enum): @@ -20,8 +22,7 @@ class AttributeType(Enum): @staticmethod def _from_pyattributetype(py_attribute_type: PyAttributeType) -> AttributeType: - """ - Converts a PyAttributeType to an AttributeType. + """Converts a PyAttributeType to an AttributeType. Args: py_attribute_type (PyAttributeType): The PyAttributeType to convert. @@ -31,30 +32,29 @@ def _from_pyattributetype(py_attribute_type: PyAttributeType) -> AttributeType: """ if py_attribute_type == PyAttributeType.Categorical: return AttributeType.Categorical - elif py_attribute_type == PyAttributeType.Continuous: + if py_attribute_type == PyAttributeType.Continuous: return AttributeType.Continuous - elif py_attribute_type == PyAttributeType.Temporal: + if py_attribute_type == PyAttributeType.Temporal: return AttributeType.Temporal + return None def _into_pyattributetype(self) -> PyAttributeType: - """ - Converts an AttributeType to a PyAttributeType. + """Converts an AttributeType to a PyAttributeType. Returns: PyAttributeType: The converted PyAttributeType. """ if self == AttributeType.Categorical: return PyAttributeType.Categorical - elif self == AttributeType.Continuous: + if self == AttributeType.Continuous: return PyAttributeType.Continuous - elif self == AttributeType.Temporal: + if self == AttributeType.Temporal: return PyAttributeType.Temporal - else: - raise NotImplementedError("Should never be reached") + msg = "Should never be reached" + raise NotImplementedError(msg) def __repr__(self) -> str: - """ - Returns a string representation of the AttributeType instance. + """Returns a string representation of the AttributeType instance. Returns: str: String representation of the attribute type. @@ -62,8 +62,7 @@ def __repr__(self) -> str: return f"AttributeType.{self.name}" def __str__(self) -> str: - """ - Returns a string representation of the AttributeType instance. + """Returns a string representation of the AttributeType instance. Returns: str: String representation of the attribute type. @@ -71,8 +70,7 @@ def __str__(self) -> str: return self.name def __eq__(self, value: object) -> bool: - """ - Compares the AttributeType instance to another object for equality. + """Compares the AttributeType instance to another object for equality. Args: value (object): The object to compare against. @@ -82,7 +80,7 @@ def __eq__(self, value: object) -> bool: """ if isinstance(value, PyAttributeType): return self._into_pyattributetype() == value - elif isinstance(value, AttributeType): + if isinstance(value, AttributeType): return str(self) == str(value) return False @@ -99,8 +97,7 @@ def __init__( MedRecordAttribute, Tuple[DataType, Optional[AttributeType]] ], ) -> None: - """ - Initializes a new instance of AttributesSchema. + """Initializes a new instance of AttributesSchema. Args: attributes_schema (Dict[MedRecordAttribute, Tuple[DataType, Optional[AttributeType]]]): @@ -113,8 +110,7 @@ def __init__( self._attributes_schema = attributes_schema def __repr__(self) -> str: - """ - Returns a string representation of the AttributesSchema instance. + """Returns a string representation of the AttributesSchema instance. Returns: str: String representation of the attribute schema. @@ -124,8 +120,7 @@ def __repr__(self) -> str: def __getitem__( self, key: MedRecordAttribute ) -> Tuple[DataType, Optional[AttributeType]]: - """ - Gets the data type and optional attribute type for a given MedRecordAttribute. + """Gets the data type and optional attribute type for a given MedRecordAttribute. Args: key (MedRecordAttribute): The attribute for which the data type is @@ -138,8 +133,7 @@ def __getitem__( return self._attributes_schema[key] def __contains__(self, key: MedRecordAttribute) -> bool: - """ - Checks if a given MedRecordAttribute is in the attributes schema. + """Checks if a given MedRecordAttribute is in the attributes schema. Args: key (MedRecordAttribute): The attribute to check. @@ -150,8 +144,7 @@ def __contains__(self, key: MedRecordAttribute) -> bool: return key in self._attributes_schema def __iter__(self): - """ - Returns an iterator over the attributes schema. + """Returns an iterator over the attributes schema. Returns: Iterator: An iterator over the attribute keys. @@ -159,8 +152,7 @@ def __iter__(self): return self._attributes_schema.__iter__() def __len__(self) -> int: - """ - Returns the number of attributes in the schema. + """Returns the number of attributes in the schema. Returns: int: The number of attributes. @@ -168,8 +160,7 @@ def __len__(self) -> int: return len(self._attributes_schema) def __eq__(self, value: object) -> bool: - """ - Compares the AttributesSchema instance to another object for equality. + """Compares the AttributesSchema instance to another object for equality. Args: value (object): The object to compare against. @@ -177,7 +168,7 @@ def __eq__(self, value: object) -> bool: Returns: bool: True if the objects are equal, False otherwise. """ - if not (isinstance(value, AttributesSchema) or isinstance(value, dict)): + if not (isinstance(value, (AttributesSchema, dict))): return False attribute_schema = ( @@ -187,7 +178,7 @@ def __eq__(self, value: object) -> bool: if not attribute_schema.keys() == self._attributes_schema.keys(): return False - for key in self._attributes_schema.keys(): + for key in self._attributes_schema: if ( not isinstance(attribute_schema[key], tuple) or not isinstance( @@ -200,8 +191,7 @@ def __eq__(self, value: object) -> bool: return True def keys(self): - """ - Returns the attribute keys in the schema. + """Returns the attribute keys in the schema. Returns: KeysView: A view object displaying a list of dictionary's keys. @@ -209,8 +199,7 @@ def keys(self): return self._attributes_schema.keys() def values(self): - """ - Returns the attribute values in the schema. + """Returns the attribute values in the schema. Returns: ValuesView: A view object displaying a list of dictionary's values. @@ -218,8 +207,7 @@ def values(self): return self._attributes_schema.values() def items(self): - """ - Returns the attribute key-value pairs in the schema. + """Returns the attribute key-value pairs in the schema. Returns: ItemsView: A set-like object providing a view on D's items. @@ -241,8 +229,7 @@ def get( key: MedRecordAttribute, default: Optional[Tuple[DataType, Optional[AttributeType]]] = None, ) -> Optional[Tuple[DataType, Optional[AttributeType]]]: - """ - Gets the data type and optional attribute type for a given attribute, returning + """Gets the data type and optional attribute type for a given attribute, returning a default value if the attribute is not present. Args: @@ -265,16 +252,15 @@ class GroupSchema: def __init__( self, *, - nodes: Dict[ - MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]] - ] = {}, - edges: Dict[ - MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]] - ] = {}, + nodes: Optional[ + Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]] + ] = None, + edges: Optional[ + Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]] + ] = None, strict: bool = False, ) -> None: - """ - Initializes a new instance of GroupSchema. + """Initializes a new instance of GroupSchema. Args: nodes (Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]): @@ -289,6 +275,10 @@ def __init__( Returns: None """ + if edges is None: + edges = {} + if nodes is None: + nodes = {} def _convert_input( input: Union[DataType, Tuple[DataType, AttributeType]], @@ -307,8 +297,7 @@ def _convert_input( @classmethod def _from_pygroupschema(cls, group_schema: PyGroupSchema) -> GroupSchema: - """ - Creates a GroupSchema instance from an existing PyGroupSchema. + """Creates a GroupSchema instance from an existing PyGroupSchema. Args: group_schema (PyGroupSchema): The PyGroupSchema instance to convert. @@ -322,8 +311,7 @@ def _from_pygroupschema(cls, group_schema: PyGroupSchema) -> GroupSchema: @property def nodes(self) -> AttributesSchema: - """ - Returns the node attributes in the GroupSchema instance. + """Returns the node attributes in the GroupSchema instance. Returns: AttributesSchema: An AttributesSchema object containing the node attributes @@ -349,8 +337,7 @@ def _convert_node( @property def edges(self) -> AttributesSchema: - """ - Returns the edge attributes in the GroupSchema instance. + """Returns the edge attributes in the GroupSchema instance. Returns: AttributesSchema: An AttributesSchema object containing the edge attributes @@ -376,8 +363,7 @@ def _convert_edge( @property def strict(self) -> Optional[bool]: - """ - Indicates whether the GroupSchema instance is strict. + """Indicates whether the GroupSchema instance is strict. Returns: Optional[bool]: True if the schema is strict, False otherwise. @@ -391,12 +377,11 @@ class Schema: def __init__( self, *, - groups: Dict[Group, GroupSchema] = {}, + groups: Optional[Dict[Group, GroupSchema]] = None, default: Optional[GroupSchema] = None, strict: bool = False, ) -> None: - """ - Initializes a new instance of Schema. + """Initializes a new instance of Schema. Args: groups (Dict[Group, GroupSchema], optional): A dictionary of group names @@ -409,6 +394,8 @@ def __init__( Returns: None """ + if groups is None: + groups = {} if default is not None: self._schema = PySchema( groups={x: groups[x]._group_schema for x in groups}, @@ -423,8 +410,7 @@ def __init__( @classmethod def _from_pyschema(cls, schema: PySchema) -> Schema: - """ - Creates a Schema instance from an existing PySchema. + """Creates a Schema instance from an existing PySchema. Args: schema (PySchema): The PySchema instance to convert. @@ -438,8 +424,7 @@ def _from_pyschema(cls, schema: PySchema) -> Schema: @property def groups(self) -> List[Group]: - """ - Lists all the groups in the Schema instance. + """Lists all the groups in the Schema instance. Returns: List[Group]: A list of groups. @@ -447,8 +432,7 @@ def groups(self) -> List[Group]: return self._schema.groups def group(self, group: Group) -> GroupSchema: - """ - Retrieves the schema for a specific group. + """Retrieves the schema for a specific group. Args: group (Group): The name of the group. @@ -463,8 +447,7 @@ def group(self, group: Group) -> GroupSchema: @property def default(self) -> Optional[GroupSchema]: - """ - Retrieves the default group schema. + """Retrieves the default group schema. Returns: Optional[GroupSchema]: The default group schema if it exists, otherwise @@ -477,8 +460,7 @@ def default(self) -> Optional[GroupSchema]: @property def strict(self) -> Optional[bool]: - """ - Indicates whether the Schema instance is strict. + """Indicates whether the Schema instance is strict. Returns: Optional[bool]: True if the schema is strict, False otherwise. diff --git a/medmodels/medrecord/tests/test_builder.py b/medmodels/medrecord/tests/test_builder.py index 71ef4ce2..bc92ac04 100644 --- a/medmodels/medrecord/tests/test_builder.py +++ b/medmodels/medrecord/tests/test_builder.py @@ -4,56 +4,56 @@ class TestMedRecordBuilder(unittest.TestCase): - def test_add_nodes(self): + def test_add_nodes(self) -> None: builder = mr.MedRecord.builder().add_nodes([("node1", {})]) - self.assertEqual(len(builder._nodes), 1) + assert len(builder._nodes) == 1 builder.add_nodes([("node2", {})], group="group") - self.assertEqual(len(builder._nodes), 2) + assert len(builder._nodes) == 2 medrecord = builder.build() - self.assertEqual(2, len(medrecord.nodes)) - self.assertEqual(1, len(medrecord.groups)) - self.assertEqual(["group"], medrecord.groups_of_node("node2")) + assert len(medrecord.nodes) == 2 + assert len(medrecord.groups) == 1 + assert medrecord.groups_of_node("node2") == ["group"] - def test_add_edges(self): + def test_add_edges(self) -> None: builder = ( mr.MedRecord.builder() .add_nodes([("node1", {}), ("node2", {})]) .add_edges([("node1", "node2", {})]) ) - self.assertEqual(len(builder._edges), 1) + assert len(builder._edges) == 1 builder.add_edges([("node2", "node1", {})], group="group") medrecord = builder.build() - self.assertEqual(2, len(medrecord.nodes)) - self.assertEqual(2, len(medrecord.edges)) - self.assertEqual(1, len(medrecord.groups)) - self.assertEqual(["node2"], medrecord.neighbors("node1")) - self.assertEqual(["group"], medrecord.groups_of_edge(1)) + assert len(medrecord.nodes) == 2 + assert len(medrecord.edges) == 2 + assert len(medrecord.groups) == 1 + assert medrecord.neighbors("node1") == ["node2"] + assert medrecord.groups_of_edge(1) == ["group"] - def test_add_group(self): + def test_add_group(self) -> None: builder = ( mr.MedRecord.builder().add_nodes(("0", {})).add_group("group", nodes=["0"]) ) - self.assertEqual(len(builder._groups), 1) + assert len(builder._groups) == 1 medrecord = builder.build() - self.assertEqual(1, len(medrecord.nodes)) - self.assertEqual(0, len(medrecord.edges)) - self.assertEqual(1, len(medrecord.groups)) - self.assertEqual("group", medrecord.groups[0]) - self.assertEqual(["0"], medrecord.nodes_in_group("group")) + assert len(medrecord.nodes) == 1 + assert len(medrecord.edges) == 0 + assert len(medrecord.groups) == 1 + assert medrecord.groups[0] == "group" + assert medrecord.nodes_in_group("group") == ["0"] - def test_with_schema(self): + def test_with_schema(self) -> None: schema = mr.Schema(default=mr.GroupSchema(nodes={"attribute": mr.Int()})) medrecord = mr.MedRecord.builder().with_schema(schema).build() diff --git a/medmodels/medrecord/tests/test_datatype.py b/medmodels/medrecord/tests/test_datatype.py index 0cf16cd5..6fd3c1c0 100644 --- a/medmodels/medrecord/tests/test_datatype.py +++ b/medmodels/medrecord/tests/test_datatype.py @@ -1,5 +1,7 @@ import unittest +import pytest + import medmodels.medrecord as mr from medmodels._medmodels import ( PyAny, @@ -15,121 +17,115 @@ class TestDataType(unittest.TestCase): - def test_string(self): + def test_string(self) -> None: string = mr.String() - self.assertTrue(isinstance(string._inner(), PyString)) + assert isinstance(string._inner(), PyString) - self.assertEqual("String", str(string)) + assert str(string) == "String" - self.assertEqual("DataType.String", string.__repr__()) + assert string.__repr__() == "DataType.String" - self.assertEqual(mr.String(), mr.String()) - self.assertNotEqual(mr.String(), mr.Int()) + assert mr.String() == mr.String() + assert mr.String() != mr.Int() - def test_int(self): + def test_int(self) -> None: integer = mr.Int() - self.assertTrue(isinstance(integer._inner(), PyInt)) + assert isinstance(integer._inner(), PyInt) - self.assertEqual("Int", str(integer)) + assert str(integer) == "Int" - self.assertEqual("DataType.Int", integer.__repr__()) + assert integer.__repr__() == "DataType.Int" - self.assertEqual(mr.Int(), mr.Int()) - self.assertNotEqual(mr.Int(), mr.String()) + assert mr.Int() == mr.Int() + assert mr.Int() != mr.String() - def test_float(self): + def test_float(self) -> None: float = mr.Float() - self.assertTrue(isinstance(float._inner(), PyFloat)) + assert isinstance(float._inner(), PyFloat) - self.assertEqual("Float", str(float)) + assert str(float) == "Float" - self.assertEqual("DataType.Float", float.__repr__()) + assert float.__repr__() == "DataType.Float" - self.assertEqual(mr.Float(), mr.Float()) - self.assertNotEqual(mr.Float(), mr.String()) + assert mr.Float() == mr.Float() + assert mr.Float() != mr.String() - def test_bool(self): + def test_bool(self) -> None: bool = mr.Bool() - self.assertTrue(isinstance(bool._inner(), PyBool)) + assert isinstance(bool._inner(), PyBool) - self.assertEqual("Bool", str(bool)) + assert str(bool) == "Bool" - self.assertEqual("DataType.Bool", bool.__repr__()) + assert bool.__repr__() == "DataType.Bool" - self.assertEqual(mr.Bool(), mr.Bool()) - self.assertNotEqual(mr.Bool(), mr.String()) + assert mr.Bool() == mr.Bool() + assert mr.Bool() != mr.String() - def test_datetime(self): + def test_datetime(self) -> None: datetime = mr.DateTime() - self.assertTrue(isinstance(datetime._inner(), PyDateTime)) + assert isinstance(datetime._inner(), PyDateTime) - self.assertEqual("DateTime", str(datetime)) + assert str(datetime) == "DateTime" - self.assertEqual("DataType.DateTime", datetime.__repr__()) + assert datetime.__repr__() == "DataType.DateTime" - self.assertEqual(mr.DateTime(), mr.DateTime()) - self.assertNotEqual(mr.DateTime(), mr.String()) + assert mr.DateTime() == mr.DateTime() + assert mr.DateTime() != mr.String() - def test_null(self): + def test_null(self) -> None: null = mr.Null() - self.assertTrue(isinstance(null._inner(), PyNull)) + assert isinstance(null._inner(), PyNull) - self.assertEqual("Null", str(null)) + assert str(null) == "Null" - self.assertEqual("DataType.Null", null.__repr__()) + assert null.__repr__() == "DataType.Null" - self.assertEqual(mr.Null(), mr.Null()) - self.assertNotEqual(mr.Null(), mr.String()) + assert mr.Null() == mr.Null() + assert mr.Null() != mr.String() - def test_any(self): + def test_any(self) -> None: any = mr.Any() - self.assertTrue(isinstance(any._inner(), PyAny)) + assert isinstance(any._inner(), PyAny) - self.assertEqual("Any", str(any)) + assert str(any) == "Any" - self.assertEqual("DataType.Any", any.__repr__()) + assert any.__repr__() == "DataType.Any" - self.assertEqual(mr.Any(), mr.Any()) - self.assertNotEqual(mr.Any(), mr.String()) + assert mr.Any() == mr.Any() + assert mr.Any() != mr.String() - def test_union(self): + def test_union(self) -> None: union = mr.Union(mr.String(), mr.Int()) - self.assertTrue(isinstance(union._inner(), PyUnion)) + assert isinstance(union._inner(), PyUnion) - self.assertEqual("Union(String, Int)", str(union)) + assert str(union) == "Union(String, Int)" - self.assertEqual( - "DataType.Union(DataType.String, DataType.Int)", union.__repr__() - ) + assert union.__repr__() == "DataType.Union(DataType.String, DataType.Int)" union = mr.Union(mr.String(), mr.Int(), mr.Bool()) - self.assertTrue(isinstance(union._inner(), PyUnion)) + assert isinstance(union._inner(), PyUnion) - self.assertEqual("Union(String, Union(Int, Bool))", str(union)) + assert str(union) == "Union(String, Union(Int, Bool))" - self.assertEqual( - "DataType.Union(DataType.String, DataType.Union(DataType.Int, DataType.Bool))", - union.__repr__(), + assert ( + union.__repr__() + == "DataType.Union(DataType.String, DataType.Union(DataType.Int, DataType.Bool))" ) - self.assertEqual( - mr.Union(mr.String(), mr.Int()), mr.Union(mr.String(), mr.Int()) - ) - self.assertNotEqual( - mr.Union(mr.String(), mr.Int()), mr.Union(mr.Int(), mr.String()) - ) + assert mr.Union(mr.String(), mr.Int()) == mr.Union(mr.String(), mr.Int()) + assert mr.Union(mr.String(), mr.Int()) != mr.Union(mr.Int(), mr.String()) - def test_invalid_union(self): - with self.assertRaises(ValueError): + def test_invalid_union(self) -> None: + with pytest.raises(ValueError): mr.Union(mr.String()) - def test_option(self): + def test_option(self) -> None: option = mr.Option(mr.String()) - self.assertTrue(isinstance(option._inner(), PyOption)) + assert isinstance(option._inner(), PyOption) - self.assertEqual("Option(String)", str(option)) + assert str(option) == "Option(String)" - self.assertEqual("DataType.Option(DataType.String)", option.__repr__()) + assert option.__repr__() == "DataType.Option(DataType.String)" - self.assertEqual(mr.Option(mr.String()), mr.Option(mr.String())) - self.assertNotEqual(mr.Option(mr.String()), mr.Option(mr.Int())) + assert mr.Option(mr.String()) == mr.Option(mr.String()) + assert mr.Option(mr.String()) != mr.Option(mr.Int()) diff --git a/medmodels/medrecord/tests/test_indexers.py b/medmodels/medrecord/tests/test_indexers.py index 17492c6d..49799a57 100644 --- a/medmodels/medrecord/tests/test_indexers.py +++ b/medmodels/medrecord/tests/test_indexers.py @@ -1,5 +1,7 @@ import unittest +import pytest + from medmodels import MedRecord from medmodels.medrecord import edge, node @@ -22,354 +24,280 @@ def create_medrecord(): class TestMedRecord(unittest.TestCase): - def test_node_getitem(self): + def test_node_getitem(self) -> None: medrecord = create_medrecord() - self.assertEqual( - {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, medrecord.node[0] - ) + assert medrecord.node[0] == {"foo": "bar", "bar": "foo", "lorem": "ipsum"} # Accessing a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.node[50] - self.assertEqual("bar", medrecord.node[0, "foo"]) + assert medrecord.node[0, "foo"] == "bar" # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[0, "test"] - self.assertEqual( - {"foo": "bar", "bar": "foo"}, medrecord.node[0, ["foo", "bar"]] - ) + assert medrecord.node[0, ["foo", "bar"]] == {"foo": "bar", "bar": "foo"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[0, ["foo", "test"]] - self.assertEqual( - {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, medrecord.node[0, :] - ) + assert medrecord.node[0, :] == {"foo": "bar", "bar": "foo", "lorem": "ipsum"} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, ::1] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.node[[0, 1]], - ) + assert medrecord.node[[0, 1]] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + } - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.node[[0, 50]] - self.assertEqual( - { - 0: "bar", - 1: "bar", - }, - medrecord.node[[0, 1], "foo"], - ) + assert medrecord.node[[0, 1], "foo"] == {0: "bar", 1: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[[0, 1], "test"] # Accessing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[[0, 1], "lorem"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.node[[0, 1], ["foo", "bar"]], - ) + assert medrecord.node[[0, 1], ["foo", "bar"]] == { + 0: {"foo": "bar", "bar": "foo"}, + 1: {"foo": "bar", "bar": "foo"}, + } # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[[0, 1], ["foo", "test"]] # Accessing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[[0, 1], ["foo", "lorem"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.node[[0, 1], :], - ) + assert medrecord.node[[0, 1], :] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], ::1] - self.assertEqual( - {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}}, - medrecord.node[node().index() >= 2], - ) + assert medrecord.node[node().index() >= 2] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Empty query should not fail - self.assertEqual( - {}, - medrecord.node[node().index() > 3], - ) + assert medrecord.node[node().index() > 3] == {} - self.assertEqual( - {2: "bar", 3: "bar"}, - medrecord.node[node().index() >= 2, "foo"], - ) + assert medrecord.node[node().index() >= 2, "foo"] == {2: "bar", 3: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[node().index() >= 2, "test"] - self.assertEqual( - { - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[node().index() >= 2, ["foo", "bar"]], - ) + assert medrecord.node[node().index() >= 2, ["foo", "bar"]] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[node().index() >= 2, ["foo", "test"]] # Accessing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[node().index() < 2, ["foo", "lorem"]] - self.assertEqual( - { - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[node().index() >= 2, :], - ) + assert medrecord.node[node().index() >= 2, :] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, ::1] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1] - self.assertEqual( - { - 0: "bar", - 1: "bar", - 2: "bar", - 3: "bar", - }, - medrecord.node[:, "foo"], - ) + assert medrecord.node[:, "foo"] == {0: "bar", 1: "bar", 2: "bar", 3: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[:, "test"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[1:, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:, ["foo", "bar"]], - ) + assert medrecord.node[:, ["foo", "bar"]] == { + 0: {"foo": "bar", "bar": "foo"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[:, ["foo", "test"]] # Accessing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[:, ["foo", "lorem"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[1:, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, ["foo", "bar"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:, :], - ) + assert medrecord.node[:, :] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[1:, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, ::1] - def test_node_setitem(self): + def test_node_setitem(self) -> None: # Updating existing attributes medrecord = create_medrecord() medrecord.node[0] = {"foo": "bar", "bar": "test"} - self.assertEqual( - { - 0: {"foo": "bar", "bar": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Updating a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.node[50] = {"foo": "bar", "test": "test"} medrecord = create_medrecord() medrecord.node[0, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[0, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[0, :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.node[0, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, ::1] = "test" medrecord = create_medrecord() medrecord.node[[0, 1], "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.node[[0, 1], 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], ::1] = "test" medrecord = create_medrecord() medrecord.node[node().index() >= 2] = {"foo": "bar", "bar": "test"} - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "test"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "test"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Empty query should not fail @@ -377,929 +305,736 @@ def test_node_setitem(self): medrecord = create_medrecord() medrecord.node[node().index() >= 2, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "foo"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "foo"}, + 3: {"foo": "test", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() >= 2, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() >= 2, :] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, ::1] = "test" medrecord = create_medrecord() medrecord.node[:, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "foo"}, - 2: {"foo": "test", "bar": "foo"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "foo"}, + 2: {"foo": "test", "bar": "foo"}, + 3: {"foo": "test", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.node[1:, "foo"] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, "foo"] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, "foo"] = "test" medrecord = create_medrecord() medrecord.node[:, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.node[1:, ["foo", "bar"]] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, ["foo", "bar"]] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, ["foo", "bar"]] = "test" medrecord = create_medrecord() medrecord.node[:, :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.node[1:, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, ::1] = "test" # Adding new attributes medrecord = create_medrecord() medrecord.node[0, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[0, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, + assert medrecord.node[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", }, - medrecord.node[:], - ) + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, + assert medrecord.node[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", }, - medrecord.node[:], - ) + 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() >= 2, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() >= 2, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: { - "foo": "bar", - "bar": "foo", - "test": "test", - "test2": "test", - }, - 3: { - "foo": "bar", - "bar": "test", - "test": "test", - "test2": "test", - }, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, + } medrecord = create_medrecord() medrecord.node[:, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "test": "test"}, - 2: {"foo": "bar", "bar": "foo", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "test": "test"}, + 2: {"foo": "bar", "bar": "foo", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test"}, + } medrecord = create_medrecord() medrecord.node[:, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", + }, + 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, + } # Adding and updating attributes medrecord = create_medrecord() medrecord.node[[0, 1], "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() < 2, "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() < 2, ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[:, "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 3: {"foo": "bar", "bar": "test", "lorem": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 3: {"foo": "bar", "bar": "test", "lorem": "test"}, + } medrecord = create_medrecord() medrecord.node[:, ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}, + } - def test_node_delitem(self): + def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[0, "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing from a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): del medrecord.node[50, "foo"] medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[0, "test"] medrecord = create_medrecord() del medrecord.node[0, ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[0, ["foo", "test"]] medrecord = create_medrecord() del medrecord.node[0, :] - self.assertEqual( - { - 0: {}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == { + 0: {}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } + + with pytest.raises(ValueError): del medrecord.node[0, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[0, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[0, ::1] medrecord = create_medrecord() del medrecord.node[[0, 1], "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing from a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): del medrecord.node[[0, 50], "foo"] medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[[0, 1], "test"] medrecord = create_medrecord() del medrecord.node[[0, 1], ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"lorem": "ipsum"}, + 1: {}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[[0, 1], ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[[0, 1], ["foo", "lorem"]] medrecord = create_medrecord() del medrecord.node[[0, 1], :] - self.assertEqual( - { - 0: {}, - 1: {}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == { + 0: {}, + 1: {}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } + + with pytest.raises(ValueError): del medrecord.node[[0, 1], 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[[0, 1], :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[[0, 1], ::1] medrecord = create_medrecord() del medrecord.node[node().index() >= 2, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"bar": "foo"}, - 3: {"bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"bar": "foo"}, + 3: {"bar": "test"}, + } medrecord = create_medrecord() # Empty query should not fail del medrecord.node[node().index() > 3, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[node().index() >= 2, "test"] medrecord = create_medrecord() del medrecord.node[node().index() >= 2, ["foo", "bar"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {}, - 3: {}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {}, + 3: {}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[node().index() >= 2, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[node().index() < 2, ["foo", "lorem"]] medrecord = create_medrecord() del medrecord.node[node().index() >= 2, :] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {}, - 3: {}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {}, + 3: {}, + } + + with pytest.raises(ValueError): del medrecord.node[node().index() >= 2, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[node().index() >= 2, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[node().index() >= 2, ::1] medrecord = create_medrecord() del medrecord.node[:, "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"bar": "foo"}, - 2: {"bar": "foo"}, - 3: {"bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"bar": "foo"}, + 2: {"bar": "foo"}, + 3: {"bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[:, "test"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[1:, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:1, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[::1, "foo"] medrecord = create_medrecord() del medrecord.node[:, ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {}, - 2: {}, - 3: {}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"lorem": "ipsum"}, 1: {}, 2: {}, 3: {}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[:, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[:, ["foo", "lorem"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[1:, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:1, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[::1, ["foo", "bar"]] medrecord = create_medrecord() del medrecord.node[:, :] - self.assertEqual( - { - 0: {}, - 1: {}, - 2: {}, - 3: {}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {}, 1: {}, 2: {}, 3: {}} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[1:, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[::1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:, ::1] - def test_edge_getitem(self): + def test_edge_getitem(self) -> None: medrecord = create_medrecord() - self.assertEqual( - {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, medrecord.edge[0] - ) + assert medrecord.edge[0] == {"foo": "bar", "bar": "foo", "lorem": "ipsum"} # Accessing a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge[50] - self.assertEqual("bar", medrecord.edge[0, "foo"]) + assert medrecord.edge[0, "foo"] == "bar" # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[0, "test"] - self.assertEqual( - {"foo": "bar", "bar": "foo"}, medrecord.edge[0, ["foo", "bar"]] - ) + assert medrecord.edge[0, ["foo", "bar"]] == {"foo": "bar", "bar": "foo"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[0, ["foo", "test"]] - self.assertEqual( - {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, medrecord.edge[0, :] - ) + assert medrecord.edge[0, :] == {"foo": "bar", "bar": "foo", "lorem": "ipsum"} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, ::1] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.edge[[0, 1]], - ) + assert medrecord.edge[[0, 1]] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + } - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge[[0, 50]] - self.assertEqual( - { - 0: "bar", - 1: "bar", - }, - medrecord.edge[[0, 1], "foo"], - ) + assert medrecord.edge[[0, 1], "foo"] == {0: "bar", 1: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[[0, 1], "test"] # Accessing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[[0, 1], "lorem"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.edge[[0, 1], ["foo", "bar"]], - ) + assert medrecord.edge[[0, 1], ["foo", "bar"]] == { + 0: {"foo": "bar", "bar": "foo"}, + 1: {"foo": "bar", "bar": "foo"}, + } # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[[0, 1], ["foo", "test"]] # Accessing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[[0, 1], ["foo", "lorem"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.edge[[0, 1], :], - ) + assert medrecord.edge[[0, 1], :] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], ::1] - self.assertEqual( - {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}}, - medrecord.edge[edge().index() >= 2], - ) + assert medrecord.edge[edge().index() >= 2] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Empty query should not fail - self.assertEqual( - {}, - medrecord.edge[edge().index() > 3], - ) + assert medrecord.edge[edge().index() > 3] == {} - self.assertEqual( - {2: "bar", 3: "bar"}, - medrecord.edge[edge().index() >= 2, "foo"], - ) + assert medrecord.edge[edge().index() >= 2, "foo"] == {2: "bar", 3: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[edge().index() >= 2, "test"] - self.assertEqual( - { - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[edge().index() >= 2, ["foo", "bar"]], - ) + assert medrecord.edge[edge().index() >= 2, ["foo", "bar"]] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[edge().index() >= 2, ["foo", "test"]] # Accessing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[edge().index() < 2, ["foo", "lorem"]] - self.assertEqual( - { - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[edge().index() >= 2, :], - ) + assert medrecord.edge[edge().index() >= 2, :] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, ::1] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1] - self.assertEqual( - { - 0: "bar", - 1: "bar", - 2: "bar", - 3: "bar", - }, - medrecord.edge[:, "foo"], - ) + assert medrecord.edge[:, "foo"] == {0: "bar", 1: "bar", 2: "bar", 3: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[:, "test"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[1:, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:, ["foo", "bar"]], - ) + assert medrecord.edge[:, ["foo", "bar"]] == { + 0: {"foo": "bar", "bar": "foo"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[:, ["foo", "test"]] # Accessing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[:, ["foo", "lorem"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[1:, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, ["foo", "bar"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:, :], - ) + assert medrecord.edge[:, :] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[1:, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, ::1] - def test_edge_setitem(self): + def test_edge_setitem(self) -> None: # Updating existing attributes medrecord = create_medrecord() medrecord.edge[0] = {"foo": "bar", "bar": "test"} - self.assertEqual( - { - 0: {"foo": "bar", "bar": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Updating a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge[50] = {"foo": "bar", "test": "test"} medrecord = create_medrecord() medrecord.edge[0, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[0, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[0, :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.edge[0, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, ::1] = "test" medrecord = create_medrecord() medrecord.edge[[0, 1], "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.edge[[0, 1], 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], ::1] = "test" medrecord = create_medrecord() medrecord.edge[edge().index() >= 2] = {"foo": "bar", "bar": "test"} - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "test"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "test"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Empty query should not fail @@ -1307,577 +1042,458 @@ def test_edge_setitem(self): medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "foo"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "foo"}, + 3: {"foo": "test", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, :] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, ::1] = "test" medrecord = create_medrecord() medrecord.edge[:, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "foo"}, - 2: {"foo": "test", "bar": "foo"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "foo"}, + 2: {"foo": "test", "bar": "foo"}, + 3: {"foo": "test", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.edge[1:, "foo"] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, "foo"] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, "foo"] = "test" medrecord = create_medrecord() medrecord.edge[:, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.edge[1:, ["foo", "bar"]] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, ["foo", "bar"]] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, ["foo", "bar"]] = "test" medrecord = create_medrecord() medrecord.edge[:, :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } + + with pytest.raises(ValueError): medrecord.edge[1:, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, ::1] = "test" # Adding new attributes medrecord = create_medrecord() medrecord.edge[0, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[0, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, + assert medrecord.edge[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", }, - medrecord.edge[:], - ) + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, + assert medrecord.edge[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", }, - medrecord.edge[:], - ) + 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: { - "foo": "bar", - "bar": "foo", - "test": "test", - "test2": "test", - }, - 3: { - "foo": "bar", - "bar": "test", - "test": "test", - "test2": "test", - }, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, + } medrecord = create_medrecord() medrecord.edge[:, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "test": "test"}, - 2: {"foo": "bar", "bar": "foo", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "test": "test"}, + 2: {"foo": "bar", "bar": "foo", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test"}, + } medrecord = create_medrecord() medrecord.edge[:, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", + }, + 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, + } # Adding and updating attributes medrecord = create_medrecord() medrecord.edge[[0, 1], "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() < 2, "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() < 2, ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[:, "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 3: {"foo": "bar", "bar": "test", "lorem": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 3: {"foo": "bar", "bar": "test", "lorem": "test"}, + } medrecord = create_medrecord() medrecord.edge[:, ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}, + } - def test_edge_delitem(self): + def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[0, "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing from a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): del medrecord.edge[50, "foo"] medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[0, "test"] medrecord = create_medrecord() del medrecord.edge[0, ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[0, ["foo", "test"]] medrecord = create_medrecord() del medrecord.edge[0, :] - self.assertEqual( - { - 0: {}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == { + 0: {}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } + + with pytest.raises(ValueError): del medrecord.edge[0, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[0, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[0, ::1] medrecord = create_medrecord() del medrecord.edge[[0, 1], "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing from a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): del medrecord.edge[[0, 50], "foo"] medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[[0, 1], "test"] medrecord = create_medrecord() del medrecord.edge[[0, 1], ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"lorem": "ipsum"}, + 1: {}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[[0, 1], ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[[0, 1], ["foo", "lorem"]] medrecord = create_medrecord() del medrecord.edge[[0, 1], :] - self.assertEqual( - { - 0: {}, - 1: {}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == { + 0: {}, + 1: {}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } + + with pytest.raises(ValueError): del medrecord.edge[[0, 1], 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[[0, 1], :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[[0, 1], ::1] medrecord = create_medrecord() del medrecord.edge[edge().index() >= 2, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"bar": "foo"}, - 3: {"bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"bar": "foo"}, + 3: {"bar": "test"}, + } medrecord = create_medrecord() # Empty query should not fail del medrecord.edge[edge().index() > 3, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[edge().index() >= 2, "test"] medrecord = create_medrecord() del medrecord.edge[edge().index() >= 2, ["foo", "bar"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {}, - 3: {}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {}, + 3: {}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[edge().index() >= 2, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[edge().index() < 2, ["foo", "lorem"]] medrecord = create_medrecord() del medrecord.edge[edge().index() >= 2, :] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {}, - 3: {}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {}, + 3: {}, + } + + with pytest.raises(ValueError): del medrecord.edge[edge().index() >= 2, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[edge().index() >= 2, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[edge().index() >= 2, ::1] medrecord = create_medrecord() del medrecord.edge[:, "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"bar": "foo"}, - 2: {"bar": "foo"}, - 3: {"bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"bar": "foo"}, + 2: {"bar": "foo"}, + 3: {"bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[:, "test"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[1:, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:1, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[::1, "foo"] medrecord = create_medrecord() del medrecord.edge[:, ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {}, - 2: {}, - 3: {}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"lorem": "ipsum"}, 1: {}, 2: {}, 3: {}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[:, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[:, ["foo", "lorem"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[1:, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:1, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[::1, ["foo", "bar"]] medrecord = create_medrecord() del medrecord.edge[:, :] - self.assertEqual( - { - 0: {}, - 1: {}, - 2: {}, - 3: {}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {}, 1: {}, 2: {}, 3: {}} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[1:, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[::1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:, ::1] diff --git a/medmodels/medrecord/tests/test_medrecord.py b/medmodels/medrecord/tests/test_medrecord.py index bcecd58a..314ca797 100644 --- a/medmodels/medrecord/tests/test_medrecord.py +++ b/medmodels/medrecord/tests/test_medrecord.py @@ -4,6 +4,7 @@ import pandas as pd import polars as pl +import pytest import medmodels.medrecord as mr from medmodels import MedRecord @@ -73,30 +74,30 @@ def create_medrecord() -> MedRecord: class TestMedRecord(unittest.TestCase): - def test_from_tuples(self): + def test_from_tuples(self) -> None: medrecord = create_medrecord() - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 4 - def test_invalid_from_tuples(self): + def test_invalid_from_tuples(self) -> None: nodes = create_nodes() # Adding an edge pointing to a non-existent node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): MedRecord.from_tuples(nodes, [("0", "50", {})]) # Adding an edge from a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): MedRecord.from_tuples(nodes, [("50", "0", {})]) - def test_from_pandas(self): + def test_from_pandas(self) -> None: medrecord = MedRecord.from_pandas( (create_pandas_nodes_dataframe(), "index"), ) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 0 medrecord = MedRecord.from_pandas( [ @@ -105,16 +106,16 @@ def test_from_pandas(self): ], ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 0 medrecord = MedRecord.from_pandas( (create_pandas_nodes_dataframe(), "index"), (create_pandas_edges_dataframe(), "source", "target"), ) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 2 medrecord = MedRecord.from_pandas( [ @@ -124,8 +125,8 @@ def test_from_pandas(self): (create_pandas_edges_dataframe(), "source", "target"), ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 2 medrecord = MedRecord.from_pandas( (create_pandas_nodes_dataframe(), "index"), @@ -135,8 +136,8 @@ def test_from_pandas(self): ], ) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 4 medrecord = MedRecord.from_pandas( [ @@ -149,10 +150,10 @@ def test_from_pandas(self): ], ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 4 - def test_from_polars(self): + def test_from_polars(self) -> None: nodes = pl.from_pandas(create_pandas_nodes_dataframe()) second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) edges = pl.from_pandas(create_pandas_edges_dataframe()) @@ -160,83 +161,83 @@ def test_from_polars(self): medrecord = MedRecord.from_polars((nodes, "index"), (edges, "source", "target")) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 2 medrecord = MedRecord.from_polars( [(nodes, "index"), (second_nodes, "index")], (edges, "source", "target") ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 2 medrecord = MedRecord.from_polars( (nodes, "index"), [(edges, "source", "target"), (second_edges, "source", "target")], ) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 4 medrecord = MedRecord.from_polars( [(nodes, "index"), (second_nodes, "index")], [(edges, "source", "target"), (second_edges, "source", "target")], ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 4 - def test_invalid_from_polars(self): + def test_invalid_from_polars(self) -> None: nodes = pl.from_pandas(create_pandas_nodes_dataframe()) second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) edges = pl.from_pandas(create_pandas_edges_dataframe()) second_edges = pl.from_pandas(create_second_pandas_edges_dataframe()) # Providing the wrong node index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars((nodes, "invalid"), (edges, "source", "target")) # Providing the wrong node index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars( [(nodes, "index"), (second_nodes, "invalid")], (edges, "source", "target"), ) # Providing the wrong source index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars((nodes, "index"), (edges, "invalid", "target")) # Providing the wrong source index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars( (nodes, "index"), [(edges, "source", "target"), (second_edges, "invalid", "target")], ) # Providing the wrong target index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars((nodes, "index"), (edges, "source", "invalid")) # Providing the wrong target index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars( (nodes, "index"), [(edges, "source", "target"), (edges, "source", "invalid")], ) - def test_from_example_dataset(self): + def test_from_example_dataset(self) -> None: medrecord = MedRecord.from_example_dataset() - self.assertEqual(73, medrecord.node_count()) - self.assertEqual(160, medrecord.edge_count()) + assert medrecord.node_count() == 73 + assert medrecord.edge_count() == 160 - self.assertEqual(25, len(medrecord.nodes_in_group("diagnosis"))) - self.assertEqual(19, len(medrecord.nodes_in_group("drug"))) - self.assertEqual(5, len(medrecord.nodes_in_group("patient"))) - self.assertEqual(24, len(medrecord.nodes_in_group("procedure"))) + assert len(medrecord.nodes_in_group("diagnosis")) == 25 + assert len(medrecord.nodes_in_group("drug")) == 19 + assert len(medrecord.nodes_in_group("patient")) == 5 + assert len(medrecord.nodes_in_group("procedure")) == 24 - def test_ron(self): + def test_ron(self) -> None: medrecord = create_medrecord() with tempfile.NamedTemporaryFile() as f: @@ -244,10 +245,10 @@ def test_ron(self): loaded_medrecord = MedRecord.from_ron(f.name) - self.assertEqual(medrecord.node_count(), loaded_medrecord.node_count()) - self.assertEqual(medrecord.edge_count(), loaded_medrecord.edge_count()) + assert medrecord.node_count() == loaded_medrecord.node_count() + assert medrecord.edge_count() == loaded_medrecord.edge_count() - def test_schema(self): + def test_schema(self) -> None: schema = mr.Schema( groups={ "group": mr.GroupSchema( @@ -292,210 +293,203 @@ def test_schema(self): with self.assertRaises(ValueError): medrecord.add_edges_to_group("group", edge_index) - def test_nodes(self): + def test_nodes(self) -> None: medrecord = create_medrecord() nodes = [x[0] for x in create_nodes()] for node in medrecord.nodes: - self.assertTrue(node in nodes) + assert node in nodes - def test_edges(self): + def test_edges(self) -> None: medrecord = create_medrecord() edges = list(range(len(create_edges()))) for edge in medrecord.edges: - self.assertTrue(edge in edges) + assert edge in edges - def test_groups(self): + def test_groups(self) -> None: medrecord = create_medrecord() medrecord.add_group("0") - self.assertEqual(["0"], medrecord.groups) + assert medrecord.groups == ["0"] - def test_group(self): + def test_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0") - self.assertEqual({"nodes": [], "edges": []}, medrecord.group("0")) + assert medrecord.group("0") == {"nodes": [], "edges": []} medrecord.add_group("1", ["0"], [0]) - self.assertEqual({"nodes": ["0"], "edges": [0]}, medrecord.group("1")) + assert medrecord.group("1") == {"nodes": ["0"], "edges": [0]} - self.assertEqual( - {"0": {"nodes": [], "edges": []}, "1": {"nodes": ["0"], "edges": [0]}}, - medrecord.group(["0", "1"]), - ) + assert medrecord.group(["0", "1"]) == { + "0": {"nodes": [], "edges": []}, + "1": {"nodes": ["0"], "edges": [0]}, + } - def test_invalid_group(self): + def test_invalid_group(self) -> None: medrecord = create_medrecord() # Querying a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.group("0") medrecord.add_group("1", ["0"]) # Querying a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.group(["0", "50"]) - def test_outgoing_edges(self): + def test_outgoing_edges(self) -> None: medrecord = create_medrecord() edges = medrecord.outgoing_edges("0") - self.assertEqual( - sorted([0, 3]), - sorted(edges), - ) + assert sorted([0, 3]) == sorted(edges) edges = medrecord.outgoing_edges(["0", "1"]) - self.assertEqual( - {"0": sorted([0, 3]), "1": [1, 2]}, - {key: sorted(value) for (key, value) in edges.items()}, - ) + assert {key: sorted(value) for key, value in edges.items()} == { + "0": sorted([0, 3]), + "1": [1, 2], + } edges = medrecord.outgoing_edges(node_select().index().is_in(["0", "1"])) - self.assertEqual( - {"0": sorted([0, 3]), "1": [1, 2]}, - {key: sorted(value) for (key, value) in edges.items()}, - ) + assert {key: sorted(value) for key, value in edges.items()} == { + "0": sorted([0, 3]), + "1": [1, 2], + } - def test_invalid_outgoing_edges(self): + def test_invalid_outgoing_edges(self) -> None: medrecord = create_medrecord() # Querying outgoing edges of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.outgoing_edges("50") # Querying outgoing edges of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.outgoing_edges(["0", "50"]) - def test_incoming_edges(self): + def test_incoming_edges(self) -> None: medrecord = create_medrecord() edges = medrecord.incoming_edges("1") - self.assertEqual([0], edges) + assert edges == [0] edges = medrecord.incoming_edges(["1", "2"]) - self.assertEqual({"1": [0], "2": [2]}, edges) + assert edges == {"1": [0], "2": [2]} edges = medrecord.incoming_edges(node_select().index().is_in(["1", "2"])) - self.assertEqual({"1": [0], "2": [2]}, edges) + assert edges == {"1": [0], "2": [2]} - def test_invalid_incoming_edges(self): + def test_invalid_incoming_edges(self) -> None: medrecord = create_medrecord() # Querying incoming edges of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.incoming_edges("50") # Querying incoming edges of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.incoming_edges(["0", "50"]) - def test_edge_endpoints(self): + def test_edge_endpoints(self) -> None: medrecord = create_medrecord() endpoints = medrecord.edge_endpoints(0) - self.assertEqual(("0", "1"), endpoints) + assert endpoints == ("0", "1") endpoints = medrecord.edge_endpoints([0, 1]) - self.assertEqual({0: ("0", "1"), 1: ("1", "0")}, endpoints) + assert endpoints == {0: ("0", "1"), 1: ("1", "0")} endpoints = medrecord.edge_endpoints(edge_select().index().is_in([0, 1])) - self.assertEqual({0: ("0", "1"), 1: ("1", "0")}, endpoints) + assert endpoints == {0: ("0", "1"), 1: ("1", "0")} - def test_invalid_edge_endpoints(self): + def test_invalid_edge_endpoints(self) -> None: medrecord = create_medrecord() # Querying endpoints of a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge_endpoints(50) # Querying endpoints of a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge_endpoints([0, 50]) - def test_edges_connecting(self): + def test_edges_connecting(self) -> None: medrecord = create_medrecord() edges = medrecord.edges_connecting("0", "1") - self.assertEqual([0], edges) + assert edges == [0] edges = medrecord.edges_connecting(["0", "1"], "1") - self.assertEqual([0], edges) + assert edges == [0] edges = medrecord.edges_connecting(node_select().index().is_in(["0", "1"]), "1") - self.assertEqual([0], edges) + assert edges == [0] edges = medrecord.edges_connecting("0", ["1", "3"]) - self.assertEqual(sorted([0, 3]), sorted(edges)) + assert sorted([0, 3]) == sorted(edges) edges = medrecord.edges_connecting("0", node_select().index().is_in(["1", "3"])) - self.assertEqual(sorted([0, 3]), sorted(edges)) + assert sorted([0, 3]) == sorted(edges) edges = medrecord.edges_connecting(["0", "1"], ["1", "2", "3"]) - self.assertEqual(sorted([0, 2, 3]), sorted(edges)) + assert sorted([0, 2, 3]) == sorted(edges) edges = medrecord.edges_connecting( node_select().index().is_in(["0", "1"]), node_select().index().is_in(["1", "2", "3"]), ) - self.assertEqual(sorted([0, 2, 3]), sorted(edges)) + assert sorted([0, 2, 3]) == sorted(edges) edges = medrecord.edges_connecting("0", "1", directed=False) - self.assertEqual([0, 1], sorted(edges)) + assert sorted(edges) == [0, 1] def test_remove_nodes(self): medrecord = create_medrecord() - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 attributes = medrecord.remove_nodes("0") - self.assertEqual(3, medrecord.node_count()) - self.assertEqual(create_nodes()[0][1], attributes) + assert medrecord.node_count() == 3 + assert create_nodes()[0][1] == attributes attributes = medrecord.remove_nodes(["1", "2"]) - self.assertEqual(1, medrecord.node_count()) - self.assertEqual( - {"1": create_nodes()[1][1], "2": create_nodes()[2][1]}, attributes - ) + assert medrecord.node_count() == 1 + assert attributes == {"1": create_nodes()[1][1], "2": create_nodes()[2][1]} medrecord = create_medrecord() - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 attributes = medrecord.remove_nodes(node_select().index().is_in(["0", "1"])) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual( - {"0": create_nodes()[0][1], "1": create_nodes()[1][1]}, attributes - ) + assert medrecord.node_count() == 2 + assert attributes == {"0": create_nodes()[0][1], "1": create_nodes()[1][1]} def test_invalid_remove_nodes(self): medrecord = create_medrecord() @@ -508,14 +502,14 @@ def test_invalid_remove_nodes(self): with self.assertRaises(IndexError): medrecord.remove_nodes(["0", "50"]) - def test_add_nodes(self): + def test_add_nodes(self) -> None: medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes(create_nodes()) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 # Adding node tuple medrecord = MedRecord() @@ -556,11 +550,11 @@ def test_add_nodes(self): # Adding pandas dataframe medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes((create_pandas_nodes_dataframe(), "index")) - self.assertEqual(2, medrecord.node_count()) + assert medrecord.node_count() == 2 # Adding pandas dataframe to a group medrecord = MedRecord() @@ -573,13 +567,13 @@ def test_add_nodes(self): # Adding polars dataframe medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 nodes = pl.from_pandas(create_pandas_nodes_dataframe()) medrecord.add_nodes((nodes, "index")) - self.assertEqual(2, medrecord.node_count()) + assert medrecord.node_count() == 2 # Adding polars dataframe to a group medrecord = MedRecord() @@ -592,7 +586,7 @@ def test_add_nodes(self): # Adding multiple pandas dataframes medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes( [ @@ -601,7 +595,7 @@ def test_add_nodes(self): ] ) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 # Adding multiple pandas dataframes to a group medrecord = MedRecord() @@ -641,7 +635,7 @@ def test_add_nodes(self): second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes( [ @@ -650,7 +644,7 @@ def test_add_nodes(self): ] ) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 # Adding multiple polars dataframes to a group medrecord = MedRecord() @@ -671,29 +665,29 @@ def test_add_nodes(self): def test_invalid_add_nodes(self): medrecord = create_medrecord() - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_nodes(create_nodes()) - def test_add_nodes_pandas(self): + def test_add_nodes_pandas(self) -> None: medrecord = MedRecord() nodes = (create_pandas_nodes_dataframe(), "index") - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes_pandas(nodes) - self.assertEqual(2, medrecord.node_count()) + assert medrecord.node_count() == 2 medrecord = MedRecord() second_nodes = (create_second_pandas_nodes_dataframe(), "index") - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes_pandas([nodes, second_nodes]) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 # Trying with the group argument medrecord = MedRecord() @@ -724,21 +718,21 @@ def test_add_nodes_polars(self): nodes = pl.from_pandas(create_pandas_nodes_dataframe()) - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes_polars((nodes, "index")) - self.assertEqual(2, medrecord.node_count()) + assert medrecord.node_count() == 2 medrecord = MedRecord() second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes_polars([(nodes, "index"), (second_nodes, "index")]) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 # Trying with the group argument medrecord = MedRecord() @@ -773,36 +767,36 @@ def test_invalid_add_nodes_polars(self): second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) # Adding a nodes dataframe with the wrong index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): medrecord.add_nodes_polars((nodes, "invalid")) # Adding a nodes dataframe with the wrong index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): medrecord.add_nodes_polars([(nodes, "index"), (second_nodes, "invalid")]) def test_remove_edges(self): medrecord = create_medrecord() - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 attributes = medrecord.remove_edges(0) - self.assertEqual(3, medrecord.edge_count()) - self.assertEqual(create_edges()[0][2], attributes) + assert medrecord.edge_count() == 3 + assert create_edges()[0][2] == attributes attributes = medrecord.remove_edges([1, 2]) - self.assertEqual(1, medrecord.edge_count()) - self.assertEqual({1: create_edges()[1][2], 2: create_edges()[2][2]}, attributes) + assert medrecord.edge_count() == 1 + assert attributes == {1: create_edges()[1][2], 2: create_edges()[2][2]} medrecord = create_medrecord() - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 attributes = medrecord.remove_edges(edge_select().index().is_in([0, 1])) - self.assertEqual(2, medrecord.edge_count()) - self.assertEqual({0: create_edges()[0][2], 1: create_edges()[1][2]}, attributes) + assert medrecord.edge_count() == 2 + assert attributes == {0: create_edges()[0][2], 1: create_edges()[1][2]} def test_invalid_remove_edges(self): medrecord = create_medrecord() @@ -811,18 +805,18 @@ def test_invalid_remove_edges(self): with self.assertRaises(IndexError): medrecord.remove_edges(50) - def test_add_edges(self): + def test_add_edges(self) -> None: medrecord = MedRecord() nodes = create_nodes() medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges(create_edges()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 # Adding single edge tuple medrecord = create_medrecord() @@ -855,11 +849,11 @@ def test_add_edges(self): medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges((create_pandas_edges_dataframe(), "source", "target")) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.edge_count() == 2 # Adding pandas dataframe to a group medrecord = MedRecord() @@ -876,13 +870,13 @@ def test_add_edges(self): medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 edges = pl.from_pandas(create_pandas_edges_dataframe()) medrecord.add_edges((edges, "source", "target")) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.edge_count() == 2 # Adding polars dataframe to a group medrecord = MedRecord() @@ -899,7 +893,7 @@ def test_add_edges(self): medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges( [ @@ -908,7 +902,7 @@ def test_add_edges(self): ] ) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 # Adding multiple pandas dataframe to a group medrecord = MedRecord() @@ -933,7 +927,7 @@ def test_add_edges(self): medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 second_edges = pl.from_pandas(create_second_pandas_edges_dataframe()) @@ -944,7 +938,7 @@ def test_add_edges(self): ] ) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 # Adding multiple polars dataframe to a group medrecord = MedRecord() @@ -988,11 +982,11 @@ def test_add_edges_pandas(self): edges = (create_pandas_edges_dataframe(), "source", "target") - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges(edges) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.edge_count() == 2 # Adding to a group medrecord = MedRecord() @@ -1026,11 +1020,11 @@ def test_add_edges_polars(self): edges = pl.from_pandas(create_pandas_edges_dataframe()) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges_polars((edges, "source", "target")) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.edge_count() == 2 # Adding to a group medrecord = MedRecord() @@ -1067,33 +1061,33 @@ def test_invalid_add_edges_polars(self): edges = pl.from_pandas(create_pandas_edges_dataframe()) # Providing the wrong source index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): medrecord.add_edges_polars((edges, "invalid", "target")) # Providing the wrong target index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): medrecord.add_edges_polars((edges, "source", "invalid")) - def test_add_group(self): + def test_add_group(self) -> None: medrecord = create_medrecord() - self.assertEqual(0, medrecord.group_count()) + assert medrecord.group_count() == 0 medrecord.add_group("0") - self.assertEqual(1, medrecord.group_count()) + assert medrecord.group_count() == 1 medrecord.add_group("1", "0", 0) - self.assertEqual(2, medrecord.group_count()) - self.assertEqual({"nodes": ["0"], "edges": [0]}, medrecord.group("1")) + assert medrecord.group_count() == 2 + assert medrecord.group("1") == {"nodes": ["0"], "edges": [0]} medrecord.add_group("2", ["0", "1"], [0, 1]) - self.assertEqual(3, medrecord.group_count()) + assert medrecord.group_count() == 3 nodes_and_edges = medrecord.group("2") - self.assertEqual(sorted(["0", "1"]), sorted(nodes_and_edges["nodes"])) - self.assertEqual(sorted([0, 1]), sorted(nodes_and_edges["edges"])) + assert sorted(["0", "1"]) == sorted(nodes_and_edges["nodes"]) + assert sorted([0, 1]) == sorted(nodes_and_edges["edges"]) medrecord.add_group( "3", @@ -1101,38 +1095,38 @@ def test_add_group(self): edge_select().index().is_in([0, 1]), ) - self.assertEqual(4, medrecord.group_count()) + assert medrecord.group_count() == 4 nodes_and_edges = medrecord.group("3") - self.assertEqual(sorted(["0", "1"]), sorted(nodes_and_edges["nodes"])) - self.assertEqual(sorted([0, 1]), sorted(nodes_and_edges["edges"])) + assert sorted(["0", "1"]) == sorted(nodes_and_edges["nodes"]) + assert sorted([0, 1]) == sorted(nodes_and_edges["edges"]) - def test_invalid_add_group(self): + def test_invalid_add_group(self) -> None: medrecord = create_medrecord() # Adding a group with a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_group("0", "50") # Adding an already existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_group("0", ["0", "50"]) medrecord.add_group("0", "0") # Adding an already existing group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_group("0") # Adding a node to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_group("0", "0") # Adding a node to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_group("0", ["1", "0"]) # Adding a node to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_group("0", node_select().index() == "0") def test_remove_groups(self): @@ -1140,11 +1134,11 @@ def test_remove_groups(self): medrecord.add_group("0") - self.assertEqual(1, medrecord.group_count()) + assert medrecord.group_count() == 1 medrecord.remove_groups("0") - self.assertEqual(0, medrecord.group_count()) + assert medrecord.group_count() == 0 def test_invalid_remove_groups(self): medrecord = create_medrecord() @@ -1158,25 +1152,19 @@ def test_add_nodes_to_group(self): medrecord.add_group("0") - self.assertEqual([], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == [] medrecord.add_nodes_to_group("0", "0") - self.assertEqual(["0"], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == ["0"] medrecord.add_nodes_to_group("0", ["1", "2"]) - self.assertEqual( - sorted(["0", "1", "2"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1", "2"]) == sorted(medrecord.nodes_in_group("0")) medrecord.add_nodes_to_group("0", node_select().index() == "3") - self.assertEqual( - sorted(["0", "1", "2", "3"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1", "2", "3"]) == sorted(medrecord.nodes_in_group("0")) def test_invalid_add_nodes_to_group(self): medrecord = create_medrecord() @@ -1216,25 +1204,19 @@ def test_add_edges_to_group(self): medrecord.add_group("0") - self.assertEqual([], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [] medrecord.add_edges_to_group("0", 0) - self.assertEqual([0], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [0] medrecord.add_edges_to_group("0", [1, 2]) - self.assertEqual( - sorted([0, 1, 2]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1, 2]) == sorted(medrecord.edges_in_group("0")) medrecord.add_edges_to_group("0", edge_select().index() == 3) - self.assertEqual( - sorted([0, 1, 2, 3]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1, 2, 3]) == sorted(medrecord.edges_in_group("0")) def test_invalid_add_edges_to_group(self): medrecord = create_medrecord() @@ -1274,36 +1256,27 @@ def test_remove_nodes_from_group(self): medrecord.add_group("0", ["0", "1"]) - self.assertEqual( - sorted(["0", "1"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1"]) == sorted(medrecord.nodes_in_group("0")) medrecord.remove_nodes_from_group("0", "1") - self.assertEqual(["0"], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == ["0"] medrecord.add_nodes_to_group("0", "1") - self.assertEqual( - sorted(["0", "1"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1"]) == sorted(medrecord.nodes_in_group("0")) medrecord.remove_nodes_from_group("0", ["0", "1"]) - self.assertEqual([], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == [] medrecord.add_nodes_to_group("0", ["0", "1"]) - self.assertEqual( - sorted(["0", "1"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1"]) == sorted(medrecord.nodes_in_group("0")) medrecord.remove_nodes_from_group("0", node_select().index().is_in(["0", "1"])) - self.assertEqual([], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == [] def test_invalid_remove_nodes_from_group(self): medrecord = create_medrecord() @@ -1335,36 +1308,27 @@ def test_remove_edges_from_group(self): medrecord.add_group("0", edges=[0, 1]) - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1]) == sorted(medrecord.edges_in_group("0")) medrecord.remove_edges_from_group("0", 1) - self.assertEqual([0], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [0] medrecord.add_edges_to_group("0", 1) - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1]) == sorted(medrecord.edges_in_group("0")) medrecord.remove_edges_from_group("0", [0, 1]) - self.assertEqual([], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [] medrecord.add_edges_to_group("0", [0, 1]) - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1]) == sorted(medrecord.edges_in_group("0")) medrecord.remove_edges_from_group("0", edge_select().index().is_in([0, 1])) - self.assertEqual([], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [] def test_invalid_remove_edges_from_group(self): medrecord = create_medrecord() @@ -1391,221 +1355,209 @@ def test_invalid_remove_edges_from_group(self): with self.assertRaises(IndexError): medrecord.remove_edges_from_group("0", [0, 50]) - def test_nodes_in_group(self): + def test_nodes_in_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", ["0", "1"]) - self.assertEqual( - sorted(["0", "1"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1"]) == sorted(medrecord.nodes_in_group("0")) - def test_invalid_nodes_in_group(self): + def test_invalid_nodes_in_group(self) -> None: medrecord = create_medrecord() # Querying nodes in a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.nodes_in_group("50") - def test_edges_in_group(self): + def test_edges_in_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", edges=[0, 1]) - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1]) == sorted(medrecord.edges_in_group("0")) - def test_invalid_edges_in_group(self): + def test_invalid_edges_in_group(self) -> None: medrecord = create_medrecord() # Querying edges in a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edges_in_group("50") - def test_groups_of_node(self): + def test_groups_of_node(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", ["0", "1"]) - self.assertEqual(["0"], medrecord.groups_of_node("0")) + assert medrecord.groups_of_node("0") == ["0"] - self.assertEqual({"0": ["0"], "1": ["0"]}, medrecord.groups_of_node(["0", "1"])) + assert medrecord.groups_of_node(["0", "1"]) == {"0": ["0"], "1": ["0"]} - self.assertEqual( - {"0": ["0"], "1": ["0"]}, - medrecord.groups_of_node(node_select().index().is_in(["0", "1"])), - ) + assert medrecord.groups_of_node(node_select().index().is_in(["0", "1"])) == { + "0": ["0"], + "1": ["0"], + } - def test_invalid_groups_of_node(self): + def test_invalid_groups_of_node(self) -> None: medrecord = create_medrecord() # Querying groups of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.groups_of_node("50") # Querying groups of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.groups_of_node(["0", "50"]) - def test_groups_of_edge(self): + def test_groups_of_edge(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", edges=[0, 1]) - self.assertEqual(["0"], medrecord.groups_of_edge(0)) + assert medrecord.groups_of_edge(0) == ["0"] - self.assertEqual({0: ["0"], 1: ["0"]}, medrecord.groups_of_edge([0, 1])) + assert medrecord.groups_of_edge([0, 1]) == {0: ["0"], 1: ["0"]} - self.assertEqual( - {0: ["0"], 1: ["0"]}, - medrecord.groups_of_edge(edge_select().index().is_in([0, 1])), - ) + assert medrecord.groups_of_edge(edge_select().index().is_in([0, 1])) == { + 0: ["0"], + 1: ["0"], + } - def test_invalid_groups_of_edge(self): + def test_invalid_groups_of_edge(self) -> None: medrecord = create_medrecord() # Querying groups of a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.groups_of_edge(50) # Querying groups of a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.groups_of_edge([0, 50]) - def test_node_count(self): + def test_node_count(self) -> None: medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes([("0", {})]) - self.assertEqual(1, medrecord.node_count()) + assert medrecord.node_count() == 1 - def test_edge_count(self): + def test_edge_count(self) -> None: medrecord = MedRecord() medrecord.add_nodes(("0", {})) medrecord.add_nodes(("1", {})) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges(("0", "1", {})) - self.assertEqual(1, medrecord.edge_count()) + assert medrecord.edge_count() == 1 - def test_group_count(self): + def test_group_count(self) -> None: medrecord = create_medrecord() - self.assertEqual(0, medrecord.group_count()) + assert medrecord.group_count() == 0 medrecord.add_group("0") - self.assertEqual(1, medrecord.group_count()) + assert medrecord.group_count() == 1 - def test_contains_node(self): + def test_contains_node(self) -> None: medrecord = create_medrecord() - self.assertTrue(medrecord.contains_node("0")) + assert medrecord.contains_node("0") - self.assertFalse(medrecord.contains_node("50")) + assert not medrecord.contains_node("50") - def test_contains_edge(self): + def test_contains_edge(self) -> None: medrecord = create_medrecord() - self.assertTrue(medrecord.contains_edge(0)) + assert medrecord.contains_edge(0) - self.assertFalse(medrecord.contains_edge(50)) + assert not medrecord.contains_edge(50) - def test_contains_group(self): + def test_contains_group(self) -> None: medrecord = create_medrecord() - self.assertFalse(medrecord.contains_group("0")) + assert not medrecord.contains_group("0") medrecord.add_group("0") - self.assertTrue(medrecord.contains_group("0")) + assert medrecord.contains_group("0") - def test_neighbors(self): + def test_neighbors(self) -> None: medrecord = create_medrecord() neighbors = medrecord.neighbors("0") - self.assertEqual( - sorted(["1", "3"]), - sorted(neighbors), - ) + assert sorted(["1", "3"]) == sorted(neighbors) neighbors = medrecord.neighbors(["0", "1"]) - self.assertEqual( - {"0": sorted(["1", "3"]), "1": ["0", "2"]}, - {key: sorted(value) for (key, value) in neighbors.items()}, - ) + assert {key: sorted(value) for key, value in neighbors.items()} == { + "0": sorted(["1", "3"]), + "1": ["0", "2"], + } neighbors = medrecord.neighbors(node_select().index().is_in(["0", "1"])) - self.assertEqual( - {"0": sorted(["1", "3"]), "1": ["0", "2"]}, - {key: sorted(value) for (key, value) in neighbors.items()}, - ) + assert {key: sorted(value) for key, value in neighbors.items()} == { + "0": sorted(["1", "3"]), + "1": ["0", "2"], + } neighbors = medrecord.neighbors("0", directed=False) - self.assertEqual( - sorted(["1", "3"]), - sorted(neighbors), - ) + assert sorted(["1", "3"]) == sorted(neighbors) neighbors = medrecord.neighbors(["0", "1"], directed=False) - self.assertEqual( - {"0": sorted(["1", "3"]), "1": ["0", "2"]}, - {key: sorted(value) for (key, value) in neighbors.items()}, - ) + assert {key: sorted(value) for key, value in neighbors.items()} == { + "0": sorted(["1", "3"]), + "1": ["0", "2"], + } neighbors = medrecord.neighbors( node_select().index().is_in(["0", "1"]), directed=False ) - self.assertEqual( - {"0": sorted(["1", "3"]), "1": ["0", "2"]}, - {key: sorted(value) for (key, value) in neighbors.items()}, - ) + assert {key: sorted(value) for key, value in neighbors.items()} == { + "0": sorted(["1", "3"]), + "1": ["0", "2"], + } - def test_invalid_neighbors(self): + def test_invalid_neighbors(self) -> None: medrecord = create_medrecord() # Querying neighbors of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.neighbors("50") # Querying neighbors of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.neighbors(["0", "50"]) # Querying undirected neighbors of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.neighbors("50", directed=False) # Querying undirected neighbors of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.neighbors(["0", "50"], directed=False) - def test_clear(self): + def test_clear(self) -> None: medrecord = create_medrecord() - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) - self.assertEqual(0, medrecord.group_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 4 + assert medrecord.group_count() == 0 medrecord.clear() - self.assertEqual(0, medrecord.node_count()) - self.assertEqual(0, medrecord.edge_count()) - self.assertEqual(0, medrecord.group_count()) + assert medrecord.node_count() == 0 + assert medrecord.edge_count() == 0 + assert medrecord.group_count() == 0 def test_clone(self): medrecord = create_medrecord() diff --git a/medmodels/medrecord/tests/test_querying.py b/medmodels/medrecord/tests/test_querying.py index 808961ab..c0464efa 100644 --- a/medmodels/medrecord/tests/test_querying.py +++ b/medmodels/medrecord/tests/test_querying.py @@ -38,1133 +38,881 @@ def create_medrecord() -> MedRecord: class TestMedRecord(unittest.TestCase): - def test_select_nodes_node(self): + def test_select_nodes_node(self) -> None: medrecord = create_medrecord() medrecord.add_group("test", ["0"]) # Node in group - self.assertEqual(["0"], medrecord.select_nodes(node().in_group("test"))) + assert medrecord.select_nodes(node().in_group("test")) == ["0"] # Node has attribute - self.assertEqual(["0"], medrecord.select_nodes(node().has_attribute("lorem"))) + assert medrecord.select_nodes(node().has_attribute("lorem")) == ["0"] # Node has outgoing edge with - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().has_outgoing_edge_with(edge().index().equal(0)) - ), - ) + assert medrecord.select_nodes( + node().has_outgoing_edge_with(edge().index().equal(0)) + ) == ["0"] # Node has incoming edge with - self.assertEqual( - ["1"], - medrecord.select_nodes( - node().has_incoming_edge_with(edge().index().equal(0)) - ), - ) + assert medrecord.select_nodes( + node().has_incoming_edge_with(edge().index().equal(0)) + ) == ["1"] # Node has edge with - self.assertEqual( - sorted(["0", "1"]), - sorted( - medrecord.select_nodes(node().has_edge_with(edge().index().equal(0))) - ), + assert sorted(["0", "1"]) == sorted( + medrecord.select_nodes(node().has_edge_with(edge().index().equal(0))) ) # Node has neighbor with - self.assertEqual( - sorted(["0", "1"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("2")) - ) - ), - ) - self.assertEqual( - sorted(["0"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("1"), directed=True) - ) - ), + assert sorted(["0", "1"]) == sorted( + medrecord.select_nodes(node().has_neighbor_with(node().index().equal("2"))) + ) + assert sorted(["0"]) == sorted( + medrecord.select_nodes( + node().has_neighbor_with(node().index().equal("1"), directed=True) + ) ) # Node has neighbor with - self.assertEqual( - sorted(["0", "2"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("1"), directed=False) - ) - ), + assert sorted(["0", "2"]) == sorted( + medrecord.select_nodes( + node().has_neighbor_with(node().index().equal("1"), directed=False) + ) ) - def test_select_nodes_node_index(self): + def test_select_nodes_node_index(self) -> None: medrecord = create_medrecord() # Index greater - self.assertEqual( - sorted(["2", "3"]), - sorted(medrecord.select_nodes(node().index().greater("1"))), + assert sorted(["2", "3"]) == sorted( + medrecord.select_nodes(node().index().greater("1")) ) # Index less - self.assertEqual( - sorted(["0", "1"]), sorted(medrecord.select_nodes(node().index().less("2"))) + assert sorted(["0", "1"]) == sorted( + medrecord.select_nodes(node().index().less("2")) ) # Index greater or equal - self.assertEqual( - sorted(["1", "2", "3"]), - sorted(medrecord.select_nodes(node().index().greater_or_equal("1"))), + assert sorted(["1", "2", "3"]) == sorted( + medrecord.select_nodes(node().index().greater_or_equal("1")) ) # Index less or equal - self.assertEqual( - sorted(["0", "1", "2"]), - sorted(medrecord.select_nodes(node().index().less_or_equal("2"))), + assert sorted(["0", "1", "2"]) == sorted( + medrecord.select_nodes(node().index().less_or_equal("2")) ) # Index equal - self.assertEqual(["1"], medrecord.select_nodes(node().index().equal("1"))) + assert medrecord.select_nodes(node().index().equal("1")) == ["1"] # Index not equal - self.assertEqual( - sorted(["0", "2", "3"]), - sorted(medrecord.select_nodes(node().index().not_equal("1"))), + assert sorted(["0", "2", "3"]) == sorted( + medrecord.select_nodes(node().index().not_equal("1")) ) # Index in - self.assertEqual(["1"], medrecord.select_nodes(node().index().is_in(["1"]))) + assert medrecord.select_nodes(node().index().is_in(["1"])) == ["1"] # Index not in - self.assertEqual( - sorted(["0", "2", "3"]), - sorted(medrecord.select_nodes(node().index().not_in(["1"]))), + assert sorted(["0", "2", "3"]) == sorted( + medrecord.select_nodes(node().index().not_in(["1"])) ) # Index starts with - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().starts_with("1")), - ) + assert medrecord.select_nodes(node().index().starts_with("1")) == ["1"] # Index ends with - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().ends_with("1")), - ) + assert medrecord.select_nodes(node().index().ends_with("1")) == ["1"] # Index contains - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().contains("1")), - ) + assert medrecord.select_nodes(node().index().contains("1")) == ["1"] - def test_select_nodes_node_attribute(self): + def test_select_nodes_node_attribute(self) -> None: medrecord = create_medrecord() # Attribute greater - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").greater("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") > "ipsum") - ) + assert medrecord.select_nodes(node().attribute("lorem").greater("ipsum")) == [] + assert medrecord.select_nodes(node().attribute("lorem") > "ipsum") == [] # Attribute less - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").less("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") < "ipsum") - ) + assert medrecord.select_nodes(node().attribute("lorem").less("ipsum")) == [] + assert medrecord.select_nodes(node().attribute("lorem") < "ipsum") == [] # Attribute greater or equal - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").greater_or_equal("ipsum")), - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") >= "ipsum") - ) + assert medrecord.select_nodes( + node().attribute("lorem").greater_or_equal("ipsum") + ) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem") >= "ipsum") == ["0"] # Attribute less or equal - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").less_or_equal("ipsum")), - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") <= "ipsum") - ) + assert medrecord.select_nodes( + node().attribute("lorem").less_or_equal("ipsum") + ) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem") <= "ipsum") == ["0"] # Attribute equal - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem").equal("ipsum")) - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") == "ipsum") - ) + assert medrecord.select_nodes(node().attribute("lorem").equal("ipsum")) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem") == "ipsum") == ["0"] # Attribute not equal - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").not_equal("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") != "ipsum") + assert ( + medrecord.select_nodes(node().attribute("lorem").not_equal("ipsum")) == [] ) + assert medrecord.select_nodes(node().attribute("lorem") != "ipsum") == [] # Attribute in - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem").is_in(["ipsum"])) - ) + assert medrecord.select_nodes(node().attribute("lorem").is_in(["ipsum"])) == [ + "0" + ] # Attribute not in - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").not_in(["ipsum"])) - ) + assert medrecord.select_nodes(node().attribute("lorem").not_in(["ipsum"])) == [] # Attribute starts with - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").starts_with("ip")), - ) + assert medrecord.select_nodes(node().attribute("lorem").starts_with("ip")) == [ + "0" + ] # Attribute ends with - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").ends_with("um")), - ) + assert medrecord.select_nodes(node().attribute("lorem").ends_with("um")) == [ + "0" + ] # Attribute contains - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").contains("su")), - ) + assert medrecord.select_nodes(node().attribute("lorem").contains("su")) == ["0"] # Attribute compare to attribute - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem")) - ), - ) - self.assertEqual( - [], + assert medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem")) + ) == ["0"] + assert ( medrecord.select_nodes( node().attribute("lorem").not_equal(node().attribute("lorem")) - ), + ) + == [] ) # Attribute compare to attribute add - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").equal(node().attribute("lorem").add("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") == node().attribute("lorem") + "10" - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").add("10")) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") + "10" - ), + ) + == [] ) + assert medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem").add("10")) + ) == ["0"] + assert medrecord.select_nodes( + node().attribute("lorem") != node().attribute("lorem") + "10" + ) == ["0"] # Attribute compare to attribute sub # Returns nothing because can't sub a string - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").equal(node().attribute("lorem").sub("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") == node().attribute("lorem") + "10" - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").not_equal(node().attribute("lorem").sub("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") != node().attribute("lorem") - "10" - ), + ) + == [] ) # Attribute compare to attribute sub - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("integer").equal(node().attribute("integer").sub(10)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sub(10)) - ), + ) + == [] ) + assert medrecord.select_nodes( + node().attribute("integer").not_equal(node().attribute("integer").sub(10)) + ) == ["0"] # Attribute compare to attribute mul - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").equal(node().attribute("lorem").mul(2)) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") == node().attribute("lorem") * 2 - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").mul(2)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") * 2 - ), + ) + == [] ) + assert medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem").mul(2)) + ) == ["0"] + assert medrecord.select_nodes( + node().attribute("lorem") != node().attribute("lorem") * 2 + ) == ["0"] # Attribute compare to attribute div # Returns nothing because can't div a string - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").equal(node().attribute("lorem").div("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") == node().attribute("lorem") / "10" - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").not_equal(node().attribute("lorem").div("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") != node().attribute("lorem") / "10" - ), + ) + == [] ) # Attribute compare to attribute div - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("integer").equal(node().attribute("integer").div(2)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").div(2)) - ), + ) + == [] ) + assert medrecord.select_nodes( + node().attribute("integer").not_equal(node().attribute("integer").div(2)) + ) == ["0"] # Attribute compare to attribute pow # Returns nothing because can't pow a string - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").equal(node().attribute("lorem").pow("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") == node().attribute("lorem") ** "10" - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").not_equal(node().attribute("lorem").pow("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") != node().attribute("lorem") ** "10" - ), + ) + == [] ) # Attribute compare to attribute pow - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").pow(2)) - ), - ) - self.assertEqual( - [], + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").pow(2)) + ) == ["0"] + assert ( medrecord.select_nodes( node() .attribute("integer") .not_equal(node().attribute("integer").pow(2)) - ), + ) + == [] ) # Attribute compare to attribute mod # Returns nothing because can't mod a string - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").equal(node().attribute("lorem").mod("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") == node().attribute("lorem") % "10" - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").not_equal(node().attribute("lorem").mod("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem") != node().attribute("lorem") % "10" - ), + ) + == [] ) # Attribute compare to attribute mod - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").mod(2)) - ), - ) - self.assertEqual( - [], + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").mod(2)) + ) == ["0"] + assert ( medrecord.select_nodes( node() .attribute("integer") .not_equal(node().attribute("integer").mod(2)) - ), + ) + == [] ) # Attribute compare to attribute round - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").round()) - ), - ) - self.assertEqual( - [], + assert medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem").round()) + ) == ["0"] + assert ( medrecord.select_nodes( node().attribute("lorem").not_equal(node().attribute("lorem").round()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").round()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").round()) - ), + ) + == [] ) + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("float").round()) + ) == ["0"] + assert medrecord.select_nodes( + node().attribute("float").not_equal(node().attribute("float").round()) + ) == ["0"] # Attribute compare to attribute round - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").ceil()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").ceil()) - ), - ) + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("float").ceil()) + ) == ["0"] + assert medrecord.select_nodes( + node().attribute("float").not_equal(node().attribute("float").ceil()) + ) == ["0"] # Attribute compare to attribute floor - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("integer").equal(node().attribute("float").floor()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").floor()) - ), + ) + == [] ) + assert medrecord.select_nodes( + node().attribute("float").not_equal(node().attribute("float").floor()) + ) == ["0"] # Attribute compare to attribute abs - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").abs()) - ), - ) - self.assertEqual( - [], + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").abs()) + ) == ["0"] + assert ( medrecord.select_nodes( node().attribute("integer").not_equal(node().attribute("integer").abs()) - ), + ) + == [] ) # Attribute compare to attribute sqrt - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").sqrt()) - ), - ) - self.assertEqual( - [], + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").sqrt()) + ) == ["0"] + assert ( medrecord.select_nodes( node() .attribute("integer") .not_equal(node().attribute("integer").sqrt()) - ), + ) + == [] ) # Attribute compare to attribute trim - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim()) - ), - ) + assert medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("dolor").trim()) + ) == ["0"] # Attribute compare to attribute trim_start - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").equal(node().attribute("dolor").trim_start()) - ), + ) + == [] ) # Attribute compare to attribute trim_end - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").equal(node().attribute("dolor").trim_end()) - ), + ) + == [] ) # Attribute compare to attribute lowercase - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("test").lowercase()) - ), - ) + assert medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("test").lowercase()) + ) == ["0"] # Attribute compare to attribute uppercase - self.assertEqual( - [], + assert ( medrecord.select_nodes( node().attribute("lorem").equal(node().attribute("test").uppercase()) - ), + ) + == [] ) - def test_select_edges_edge(self): + def test_select_edges_edge(self) -> None: medrecord = create_medrecord() medrecord.add_group("test", edges=[0]) # Edge connected to target - self.assertEqual( - sorted([1, 2, 3]), - sorted(medrecord.select_edges(edge().connected_target("2"))), + assert sorted([1, 2, 3]) == sorted( + medrecord.select_edges(edge().connected_target("2")) ) # Edge connected to source - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().connected_source("0"))), + assert sorted([0, 2, 3]) == sorted( + medrecord.select_edges(edge().connected_source("0")) ) # Edge connected - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.select_edges(edge().connected("1"))), - ) + assert sorted([0, 1]) == sorted(medrecord.select_edges(edge().connected("1"))) # Edge in group - self.assertEqual( - [0], - medrecord.select_edges(edge().in_group("test")), - ) + assert medrecord.select_edges(edge().in_group("test")) == [0] # Edge has attribute - self.assertEqual( - [0], - medrecord.select_edges(edge().has_attribute("sed")), - ) + assert medrecord.select_edges(edge().has_attribute("sed")) == [0] # Edge connected to target with - self.assertEqual( - [0], - medrecord.select_edges( - edge().connected_target_with(node().index().equal("1")) - ), - ) + assert medrecord.select_edges( + edge().connected_target_with(node().index().equal("1")) + ) == [0] # Edge connected to source with - self.assertEqual( - sorted([0, 2, 3]), - sorted( - medrecord.select_edges( - edge().connected_source_with(node().index().equal("0")) - ) - ), + assert sorted([0, 2, 3]) == sorted( + medrecord.select_edges( + edge().connected_source_with(node().index().equal("0")) + ) ) # Edge connected with - self.assertEqual( - sorted([0, 1]), - sorted( - medrecord.select_edges(edge().connected_with(node().index().equal("1"))) - ), + assert sorted([0, 1]) == sorted( + medrecord.select_edges(edge().connected_with(node().index().equal("1"))) ) # Edge has parallel edges with - self.assertEqual( - sorted([2, 3]), - sorted( - medrecord.select_edges( - edge().has_parallel_edges_with(edge().has_attribute("test")) - ) - ), + assert sorted([2, 3]) == sorted( + medrecord.select_edges( + edge().has_parallel_edges_with(edge().has_attribute("test")) + ) ) # Edge has parallel edges with self comparison - self.assertEqual( - [2], - medrecord.select_edges( - edge().has_parallel_edges_with_self_comparison( - edge().attribute("test").equal(edge().attribute("test").sub(1)) - ) - ), - ) + assert medrecord.select_edges( + edge().has_parallel_edges_with_self_comparison( + edge().attribute("test").equal(edge().attribute("test").sub(1)) + ) + ) == [2] - def test_select_edges_edge_index(self): + def test_select_edges_edge_index(self) -> None: medrecord = create_medrecord() # Index greater - self.assertEqual( - sorted([2, 3]), - sorted(medrecord.select_edges(edge().index().greater(1))), + assert sorted([2, 3]) == sorted( + medrecord.select_edges(edge().index().greater(1)) ) # Index less - self.assertEqual( - [0], - medrecord.select_edges(edge().index().less(1)), - ) + assert medrecord.select_edges(edge().index().less(1)) == [0] # Index greater or equal - self.assertEqual( - sorted([1, 2, 3]), - sorted(medrecord.select_edges(edge().index().greater_or_equal(1))), + assert sorted([1, 2, 3]) == sorted( + medrecord.select_edges(edge().index().greater_or_equal(1)) ) # Index less or equal - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.select_edges(edge().index().less_or_equal(1))), + assert sorted([0, 1]) == sorted( + medrecord.select_edges(edge().index().less_or_equal(1)) ) # Index equal - self.assertEqual( - [1], - medrecord.select_edges(edge().index().equal(1)), - ) + assert medrecord.select_edges(edge().index().equal(1)) == [1] # Index not equal - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().index().not_equal(1))), + assert sorted([0, 2, 3]) == sorted( + medrecord.select_edges(edge().index().not_equal(1)) ) # Index in - self.assertEqual( - [1], - medrecord.select_edges(edge().index().is_in([1])), - ) + assert medrecord.select_edges(edge().index().is_in([1])) == [1] # Index not in - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().index().not_in([1]))), + assert sorted([0, 2, 3]) == sorted( + medrecord.select_edges(edge().index().not_in([1])) ) - def test_select_edges_edges_attribute(self): + def test_select_edges_edges_attribute(self) -> None: medrecord = create_medrecord() # Attribute greater - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").greater("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").greater("do")) == [] # Attribute less - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").less("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").less("do")) == [] # Attribute greater or equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").greater_or_equal("do")), - ) + assert medrecord.select_edges( + edge().attribute("sed").greater_or_equal("do") + ) == [0] # Attribute less or equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").less_or_equal("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").less_or_equal("do")) == [ + 0 + ] # Attribute equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").equal("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").equal("do")) == [0] # Attribute not equal - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").not_equal("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").not_equal("do")) == [] # Attribute in - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").is_in(["do"])), - ) + assert medrecord.select_edges(edge().attribute("sed").is_in(["do"])) == [0] # Attribute not in - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").not_in(["do"])), - ) + assert medrecord.select_edges(edge().attribute("sed").not_in(["do"])) == [] # Attribute starts with - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").starts_with("d")), - ) + assert medrecord.select_edges(edge().attribute("sed").starts_with("d")) == [0] # Attribute ends with - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").ends_with("o")), - ) + assert medrecord.select_edges(edge().attribute("sed").ends_with("o")) == [0] # Attribute contains - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").contains("d")), - ) + assert medrecord.select_edges(edge().attribute("sed").contains("d")) == [0] # Attribute compare to attribute - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed")) - ), - ) - self.assertEqual( - [], + assert medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("sed")) + ) == [0] + assert ( medrecord.select_edges( edge().attribute("sed").not_equal(edge().attribute("sed")) - ), + ) + == [] ) # Attribute compare to attribute add - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed").equal(edge().attribute("sed").add("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed") == edge().attribute("sed") + "10" - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").add("10")) - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") + "10" - ), + ) + == [] ) + assert medrecord.select_edges( + edge().attribute("sed").not_equal(edge().attribute("sed").add("10")) + ) == [0] + assert medrecord.select_edges( + edge().attribute("sed") != edge().attribute("sed") + "10" + ) == [0] # Attribute compare to attribute sub # Returns nothing because can't sub a string - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed").equal(edge().attribute("sed").sub("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed") == edge().attribute("sed") - "10" - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed").not_equal(edge().attribute("sed").sub("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed") != edge().attribute("sed") - "10" - ), + ) + == [] ) # Attribute compare to attribute sub - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("integer").equal(edge().attribute("integer").sub(10)) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sub(10)) - ), + ) + == [] ) + assert medrecord.select_edges( + edge().attribute("integer").not_equal(edge().attribute("integer").sub(10)) + ) == [2] # Attribute compare to attribute mul - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed").equal(edge().attribute("sed").mul(2)) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed") == edge().attribute("sed") * 2 - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").mul(2)) - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") * 2 - ), + ) + == [] ) + assert medrecord.select_edges( + edge().attribute("sed").not_equal(edge().attribute("sed").mul(2)) + ) == [0] + assert medrecord.select_edges( + edge().attribute("sed") != edge().attribute("sed") * 2 + ) == [0] # Attribute compare to attribute div # Returns nothing because can't div a string - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed").equal(edge().attribute("sed").div("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed") == edge().attribute("sed") / "10" - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed").not_equal(edge().attribute("sed").div("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed") != edge().attribute("sed") / "10" - ), + ) + == [] ) # Attribute compare to attribute div - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("integer").equal(edge().attribute("integer").div(2)) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").div(2)) - ), + ) + == [] ) + assert medrecord.select_edges( + edge().attribute("integer").not_equal(edge().attribute("integer").div(2)) + ) == [2] # Attribute compare to attribute pow # Returns nothing because can't pow a string - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("lorem").equal(edge().attribute("lorem").pow("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("lorem") == edge().attribute("lorem") ** "10" - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("lorem").not_equal(edge().attribute("lorem").pow("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("lorem") != edge().attribute("lorem") ** "10" - ), + ) + == [] ) # Attribute compare to attribute pow - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").pow(2)) - ), - ) - self.assertEqual( - [], + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").pow(2)) + ) == [2] + assert ( medrecord.select_edges( edge() .attribute("integer") .not_equal(edge().attribute("integer").pow(2)) - ), + ) + == [] ) # Attribute compare to attribute mod # Returns nothing because can't mod a string - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("lorem").equal(edge().attribute("lorem").mod("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("lorem") == edge().attribute("lorem") % "10" - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("lorem").not_equal(edge().attribute("lorem").mod("10")) - ), + ) + == [] ) - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("lorem") != edge().attribute("lorem") % "10" - ), + ) + == [] ) # Attribute compare to attribute mod - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").mod(2)) - ), - ) - self.assertEqual( - [], + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").mod(2)) + ) == [2] + assert ( medrecord.select_edges( edge() .attribute("integer") .not_equal(edge().attribute("integer").mod(2)) - ), + ) + == [] ) # Attribute compare to attribute round - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").round()) - ), - ) - self.assertEqual( - [], + assert medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("sed").round()) + ) == [0] + assert ( medrecord.select_edges( edge().attribute("sed").not_equal(edge().attribute("sed").round()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").round()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").round()) - ), + ) + == [] ) + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("float").round()) + ) == [2] + assert medrecord.select_edges( + edge().attribute("float").not_equal(edge().attribute("float").round()) + ) == [2] # Attribute compare to attribute ceil - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").ceil()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").ceil()) - ), - ) + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("float").ceil()) + ) == [2] + assert medrecord.select_edges( + edge().attribute("float").not_equal(edge().attribute("float").ceil()) + ) == [2] # Attribute compare to attribute floor - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("integer").equal(edge().attribute("float").floor()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").floor()) - ), + ) + == [] ) + assert medrecord.select_edges( + edge().attribute("float").not_equal(edge().attribute("float").floor()) + ) == [2] # Attribute compare to attribute abs - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").abs()) - ), - ) - self.assertEqual( - [], + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").abs()) + ) == [2] + assert ( medrecord.select_edges( edge().attribute("integer").not_equal(edge().attribute("integer").abs()) - ), + ) + == [] ) # Attribute compare to attribute sqrt - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").sqrt()) - ), - ) - self.assertEqual( - [], + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").sqrt()) + ) == [2] + assert ( medrecord.select_edges( edge() .attribute("integer") .not_equal(edge().attribute("integer").sqrt()) - ), + ) + == [] ) # Attribute compare to attribute trim - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim()) - ), - ) + assert medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("dolor").trim()) + ) == [0] # Attribute compare to attribute trim_start - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed").equal(edge().attribute("dolor").trim_start()) - ), + ) + == [] ) # Attribute compare to attribute trim_end - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed").equal(edge().attribute("dolor").trim_end()) - ), + ) + == [] ) # Attribute compare to attribute lowercase - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("test").lowercase()) - ), - ) + assert medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("test").lowercase()) + ) == [0] # Attribute compare to attribute uppercase - self.assertEqual( - [], + assert ( medrecord.select_edges( edge().attribute("sed").equal(edge().attribute("test").uppercase()) - ), + ) + == [] ) diff --git a/medmodels/medrecord/tests/test_schema.py b/medmodels/medrecord/tests/test_schema.py index f0dc98fd..e4a72064 100644 --- a/medmodels/medrecord/tests/test_schema.py +++ b/medmodels/medrecord/tests/test_schema.py @@ -1,5 +1,7 @@ import unittest +import pytest + import medmodels.medrecord as mr from medmodels._medmodels import PyAttributeType from medmodels.medrecord.schema import GroupSchema, Schema @@ -10,69 +12,60 @@ def create_medrecord() -> mr.MedRecord: class TestSchema(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.schema = create_medrecord().schema - def test_groups(self): - self.assertEqual( - sorted( - [ - "diagnosis", - "drug", - "patient_diagnosis", - "patient_drug", - "patient_procedure", - "patient", - "procedure", - ] - ), - sorted(self.schema.groups), - ) - - def test_group(self): - self.assertTrue(isinstance(self.schema.group("patient"), mr.GroupSchema)) # pyright: ignore[reportUnnecessaryIsInstance] - - with self.assertRaises(ValueError): + def test_groups(self) -> None: + assert sorted( + [ + "diagnosis", + "drug", + "patient_diagnosis", + "patient_drug", + "patient_procedure", + "patient", + "procedure", + ] + ) == sorted(self.schema.groups) + + def test_group(self) -> None: + assert isinstance(self.schema.group("patient"), mr.GroupSchema) # pyright: ignore[reportUnnecessaryIsInstance] + + with pytest.raises(ValueError): self.schema.group("nonexistent") - def test_default(self): - self.assertEqual(None, self.schema.default) + def test_default(self) -> None: + assert None is self.schema.default schema = Schema(default=GroupSchema(nodes={"description": mr.String()})) - self.assertTrue(isinstance(schema.default, mr.GroupSchema)) + assert isinstance(schema.default, mr.GroupSchema) - def test_strict(self): - self.assertEqual(True, self.schema.strict) + def test_strict(self) -> None: + assert True is self.schema.strict class TestGroupSchema(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.schema = create_medrecord().schema - def test_nodes(self): - self.assertEqual( - { - "age": (mr.Int(), mr.AttributeType.Continuous), - "gender": (mr.String(), mr.AttributeType.Categorical), - }, - self.schema.group("patient").nodes, - ) + def test_nodes(self) -> None: + assert self.schema.group("patient").nodes == { + "age": (mr.Int(), mr.AttributeType.Continuous), + "gender": (mr.String(), mr.AttributeType.Categorical), + } - def test_edges(self): - self.assertEqual( - { - "diagnosis_time": (mr.DateTime(), mr.AttributeType.Temporal), - "duration_days": (mr.Option(mr.Float()), mr.AttributeType.Continuous), - }, - self.schema.group("patient_diagnosis").edges, - ) + def test_edges(self) -> None: + assert self.schema.group("patient_diagnosis").edges == { + "diagnosis_time": (mr.DateTime(), mr.AttributeType.Temporal), + "duration_days": (mr.Option(mr.Float()), mr.AttributeType.Continuous), + } - def test_strict(self): - self.assertEqual(True, self.schema.group("patient").strict) + def test_strict(self) -> None: + assert True is self.schema.group("patient").strict class TestAttributesSchema(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.attributes_schema = ( Schema( groups={"diagnosis": GroupSchema(nodes={"description": mr.String()})}, @@ -82,10 +75,9 @@ def setUp(self): .nodes ) - def test_repr(self): - self.assertEqual( - "{'description': (DataType.String, None)}", - repr(self.attributes_schema), + def test_repr(self) -> None: + assert ( + repr(self.attributes_schema) == "{'description': (DataType.String, None)}" ) second_attributes_schema = ( @@ -103,28 +95,25 @@ def test_repr(self): .nodes ) - self.assertEqual( - "{'description': (DataType.String, AttributeType.Categorical)}", - repr(second_attributes_schema), + assert ( + repr(second_attributes_schema) + == "{'description': (DataType.String, AttributeType.Categorical)}" ) - def test_getitem(self): - self.assertEqual( - (mr.String(), None), - self.attributes_schema["description"], - ) + def test_getitem(self) -> None: + assert (mr.String(), None) == self.attributes_schema["description"] - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.attributes_schema["nonexistent"] - def test_contains(self): - self.assertTrue("description" in self.attributes_schema) - self.assertFalse("nonexistent" in self.attributes_schema) + def test_contains(self) -> None: + assert "description" in self.attributes_schema + assert "nonexistent" not in self.attributes_schema - def test_len(self): - self.assertEqual(1, len(self.attributes_schema)) + def test_len(self) -> None: + assert len(self.attributes_schema) == 1 - def test_eq(self): + def test_eq(self) -> None: comparison_attributes_schema = ( Schema( groups={"diagnosis": GroupSchema(nodes={"description": mr.String()})}, @@ -134,7 +123,7 @@ def test_eq(self): .nodes ) - self.assertEqual(self.attributes_schema, comparison_attributes_schema) + assert self.attributes_schema == comparison_attributes_schema comparison_attributes_schema = ( Schema( @@ -145,7 +134,7 @@ def test_eq(self): .nodes ) - self.assertNotEqual(self.attributes_schema, comparison_attributes_schema) + assert self.attributes_schema != comparison_attributes_schema comparison_attributes_schema = ( Schema( @@ -162,7 +151,7 @@ def test_eq(self): .nodes ) - self.assertNotEqual(self.attributes_schema, comparison_attributes_schema) + assert self.attributes_schema != comparison_attributes_schema comparison_attributes_schema = ( Schema( @@ -179,60 +168,49 @@ def test_eq(self): .nodes ) - self.assertNotEqual(self.attributes_schema, comparison_attributes_schema) + assert self.attributes_schema != comparison_attributes_schema - self.assertNotEqual(self.attributes_schema, None) + assert self.attributes_schema is not None - def test_keys(self): - self.assertEqual(["description"], list(self.attributes_schema.keys())) + def test_keys(self) -> None: + assert list(self.attributes_schema.keys()) == ["description"] - def test_values(self): - self.assertEqual( - [(mr.String(), None)], - list(self.attributes_schema.values()), - ) + def test_values(self) -> None: + assert [(mr.String(), None)] == list(self.attributes_schema.values()) - def test_items(self): - self.assertEqual( - [("description", (mr.String(), None))], - list(self.attributes_schema.items()), + def test_items(self) -> None: + assert [("description", (mr.String(), None))] == list( + self.attributes_schema.items() ) - def test_get(self): - self.assertEqual( - (mr.String(), None), - self.attributes_schema.get("description"), - ) + def test_get(self) -> None: + assert (mr.String(), None) == self.attributes_schema.get("description") - self.assertEqual( - None, - self.attributes_schema.get("nonexistent"), - ) + assert None is self.attributes_schema.get("nonexistent") - self.assertEqual( - (mr.String(), None), - self.attributes_schema.get("nonexistent", (mr.String(), None)), + assert (mr.String(), None) == self.attributes_schema.get( + "nonexistent", (mr.String(), None) ) class TestAttributeType(unittest.TestCase): - def test_str(self): - self.assertEqual("Categorical", str(mr.AttributeType.Categorical)) - self.assertEqual("Continuous", str(mr.AttributeType.Continuous)) - self.assertEqual("Temporal", str(mr.AttributeType.Temporal)) - - def test_eq(self): - self.assertEqual(mr.AttributeType.Categorical, mr.AttributeType.Categorical) - self.assertEqual(mr.AttributeType.Categorical, PyAttributeType.Categorical) - self.assertNotEqual(mr.AttributeType.Categorical, mr.AttributeType.Continuous) - self.assertNotEqual(mr.AttributeType.Categorical, PyAttributeType.Continuous) - - self.assertEqual(mr.AttributeType.Continuous, mr.AttributeType.Continuous) - self.assertEqual(mr.AttributeType.Continuous, PyAttributeType.Continuous) - self.assertNotEqual(mr.AttributeType.Continuous, mr.AttributeType.Categorical) - self.assertNotEqual(mr.AttributeType.Continuous, PyAttributeType.Categorical) - - self.assertEqual(mr.AttributeType.Temporal, mr.AttributeType.Temporal) - self.assertEqual(mr.AttributeType.Temporal, PyAttributeType.Temporal) - self.assertNotEqual(mr.AttributeType.Temporal, mr.AttributeType.Categorical) - self.assertNotEqual(mr.AttributeType.Temporal, PyAttributeType.Categorical) + def test_str(self) -> None: + assert str(mr.AttributeType.Categorical) == "Categorical" + assert str(mr.AttributeType.Continuous) == "Continuous" + assert str(mr.AttributeType.Temporal) == "Temporal" + + def test_eq(self) -> None: + assert mr.AttributeType.Categorical == mr.AttributeType.Categorical + assert mr.AttributeType.Categorical == PyAttributeType.Categorical + assert mr.AttributeType.Categorical != mr.AttributeType.Continuous + assert mr.AttributeType.Categorical != PyAttributeType.Continuous + + assert mr.AttributeType.Continuous == mr.AttributeType.Continuous + assert mr.AttributeType.Continuous == PyAttributeType.Continuous + assert mr.AttributeType.Continuous != mr.AttributeType.Categorical + assert mr.AttributeType.Continuous != PyAttributeType.Categorical + + assert mr.AttributeType.Temporal == mr.AttributeType.Temporal + assert mr.AttributeType.Temporal == PyAttributeType.Temporal + assert mr.AttributeType.Temporal != mr.AttributeType.Categorical + assert mr.AttributeType.Temporal != PyAttributeType.Categorical diff --git a/medmodels/treatment_effect/builder.py b/medmodels/treatment_effect/builder.py index 9a6e5be6..8615c744 100644 --- a/medmodels/treatment_effect/builder.py +++ b/medmodels/treatment_effect/builder.py @@ -1,16 +1,18 @@ from __future__ import annotations -from typing import Any, Dict, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional import medmodels.treatment_effect.treatment_effect as tee -from medmodels.medrecord.querying import NodeOperation -from medmodels.medrecord.types import ( - Group, - MedRecordAttribute, - MedRecordAttributeInputList, -) -from medmodels.treatment_effect.matching.algorithms.propensity_score import Model -from medmodels.treatment_effect.matching.matching import MatchingMethod + +if TYPE_CHECKING: + from medmodels.medrecord.querying import NodeOperation + from medmodels.medrecord.types import ( + Group, + MedRecordAttribute, + MedRecordAttributeInputList, + ) + from medmodels.treatment_effect.matching.algorithms.propensity_score import Model + from medmodels.treatment_effect.matching.matching import MatchingMethod class TreatmentEffectBuilder: @@ -218,8 +220,8 @@ def filter_controls(self, operation: NodeOperation) -> TreatmentEffectBuilder: def with_propensity_matching( self, - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: MedRecordAttributeInputList = None, + one_hot_covariates: MedRecordAttributeInputList = None, model: Model = "logit", number_of_neighbors: int = 1, hyperparam: Optional[Dict[str, Any]] = None, @@ -244,6 +246,10 @@ def with_propensity_matching( TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder with updated matching configurations. """ + if one_hot_covariates is None: + one_hot_covariates = ["gender"] + if essential_covariates is None: + essential_covariates = ["gender", "age"] self.matching_method = "propensity" self.matching_essential_covariates = essential_covariates self.matching_one_hot_covariates = one_hot_covariates @@ -255,8 +261,8 @@ def with_propensity_matching( def with_nearest_neighbors_matching( self, - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: MedRecordAttributeInputList = None, + one_hot_covariates: MedRecordAttributeInputList = None, number_of_neighbors: int = 1, ) -> TreatmentEffectBuilder: """Adjust the treatment effect estimate using nearest neighbors matching. @@ -275,6 +281,10 @@ def with_nearest_neighbors_matching( TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder with updated matching configurations. """ + if one_hot_covariates is None: + one_hot_covariates = ["gender"] + if essential_covariates is None: + essential_covariates = ["gender", "age"] self.matching_method = "nearest_neighbors" self.matching_essential_covariates = essential_covariates self.matching_one_hot_covariates = one_hot_covariates diff --git a/medmodels/treatment_effect/continuous_estimators.py b/medmodels/treatment_effect/continuous_estimators.py index 5aa89e06..a6cc5c34 100644 --- a/medmodels/treatment_effect/continuous_estimators.py +++ b/medmodels/treatment_effect/continuous_estimators.py @@ -74,7 +74,8 @@ def average_treatment_effect( ] ) if not all(isinstance(i, (int, float)) for i in treated_outcomes): - raise ValueError("Outcome variable must be numeric") + msg = "Outcome variable must be numeric" + raise ValueError(msg) control_outcomes = np.array( [ @@ -91,7 +92,8 @@ def average_treatment_effect( ] ) if not all(isinstance(i, (int, float)) for i in control_outcomes): - raise ValueError("Outcome variable must be numeric") + msg = "Outcome variable must be numeric" + raise ValueError(msg) return treated_outcomes.mean() - control_outcomes.mean() @@ -168,7 +170,8 @@ def cohens_d( ] ) if not all(isinstance(i, (int, float)) for i in treated_outcomes): - raise ValueError("Outcome variable must be numeric") + msg = "Outcome variable must be numeric" + raise ValueError(msg) control_outcomes = np.array( [ @@ -185,7 +188,8 @@ def cohens_d( ] ) if not all(isinstance(i, (int, float)) for i in control_outcomes): - raise ValueError("Outcome variable must be numeric") + msg = "Outcome variable must be numeric" + raise ValueError(msg) min_len = min(len(treated_outcomes), len(control_outcomes)) cf = 1 # correction factor diff --git a/medmodels/treatment_effect/estimate.py b/medmodels/treatment_effect/estimate.py index 44dae744..71980781 100644 --- a/medmodels/treatment_effect/estimate.py +++ b/medmodels/treatment_effect/estimate.py @@ -2,17 +2,17 @@ from typing import TYPE_CHECKING, Literal, Set, Tuple, TypedDict -from medmodels.medrecord.medrecord import MedRecord -from medmodels.medrecord.types import MedRecordAttribute, NodeIndex from medmodels.treatment_effect.continuous_estimators import ( average_treatment_effect, cohens_d, ) -from medmodels.treatment_effect.matching.matching import Matching from medmodels.treatment_effect.matching.neighbors import NeighborsMatching from medmodels.treatment_effect.matching.propensity import PropensityMatching if TYPE_CHECKING: + from medmodels.medrecord.medrecord import MedRecord + from medmodels.medrecord.types import MedRecordAttribute, NodeIndex + from medmodels.treatment_effect.matching.matching import Matching from medmodels.treatment_effect.treatment_effect import TreatmentEffect @@ -128,20 +128,23 @@ def _check_medrecord(self, medrecord: MedRecord) -> None: MedRecord (patients, treatments, outcomes). """ if self._treatment_effect._patients_group not in medrecord.groups: - raise ValueError( + msg = ( f"Patient group {self._treatment_effect._patients_group} not found in " f"the MedRecord. Available groups: {medrecord.groups}" ) + raise ValueError(msg) if self._treatment_effect._treatments_group not in medrecord.groups: - raise ValueError( + msg = ( "Treatment group not found in the MedRecord. " f"Available groups: {medrecord.groups}" ) + raise ValueError(msg) if self._treatment_effect._outcomes_group not in medrecord.groups: - raise ValueError( + msg = ( "Outcome group not found in the MedRecord." f"Available groups: {medrecord.groups}" ) + raise ValueError(msg) def _sort_subjects_in_groups( self, medrecord: MedRecord @@ -531,9 +534,8 @@ def hazard_ratio(self, medrecord: MedRecord) -> float: ) if hazard_control == 0: - raise ValueError( - "Control hazard rate is zero, cannot calculate hazard ratio." - ) + msg = "Control hazard rate is zero, cannot calculate hazard ratio." + raise ValueError(msg) return hazard_treat / hazard_control diff --git a/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py b/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py index 9d6e9141..19970a28 100644 --- a/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py +++ b/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py @@ -35,9 +35,8 @@ def nearest_neighbor( unit. """ if treated_set.shape[0] * number_of_neighbors > control_set.shape[0]: - raise ValueError( - "The treated set is too large for the given number of neighbors." - ) + msg = "The treated set is too large for the given number of neighbors." + raise ValueError(msg) if not covariates: covariates = treated_set.columns diff --git a/medmodels/treatment_effect/matching/algorithms/propensity_score.py b/medmodels/treatment_effect/matching/algorithms/propensity_score.py index 92ca7c4b..221c2fad 100644 --- a/medmodels/treatment_effect/matching/algorithms/propensity_score.py +++ b/medmodels/treatment_effect/matching/algorithms/propensity_score.py @@ -4,12 +4,10 @@ import numpy as np import polars as pl -from numpy.typing import NDArray from sklearn.ensemble import RandomForestClassifier from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier -from medmodels.medrecord.types import MedRecordAttributeInputList from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) @@ -17,6 +15,10 @@ if TYPE_CHECKING: import sys + from numpy.typing import NDArray + + from medmodels.medrecord.types import MedRecordAttributeInputList + if sys.version_info >= (3, 10): from typing import TypeAlias else: diff --git a/medmodels/treatment_effect/matching/matching.py b/medmodels/treatment_effect/matching/matching.py index 5f276ea5..184e97ee 100644 --- a/medmodels/treatment_effect/matching/matching.py +++ b/medmodels/treatment_effect/matching/matching.py @@ -1,16 +1,16 @@ from __future__ import annotations -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Literal, Set, Tuple import polars as pl -from medmodels.medrecord.medrecord import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex - if TYPE_CHECKING: import sys + from medmodels.medrecord.medrecord import MedRecord + from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex + if sys.version_info >= (3, 10): from typing import TypeAlias else: @@ -19,7 +19,7 @@ MatchingMethod: TypeAlias = Literal["propensity", "nearest_neighbors"] -class Matching(metaclass=ABCMeta): +class Matching(ABC): """The Base Class for matching.""" def _preprocess_data( diff --git a/medmodels/treatment_effect/matching/neighbors.py b/medmodels/treatment_effect/matching/neighbors.py index 5fb71e9d..f0d14bd1 100644 --- a/medmodels/treatment_effect/matching/neighbors.py +++ b/medmodels/treatment_effect/matching/neighbors.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import Set +from typing import TYPE_CHECKING, Set -from medmodels import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) from medmodels.treatment_effect.matching.matching import Matching +if TYPE_CHECKING: + from medmodels import MedRecord + from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex + class NeighborsMatching(Matching): """Class for the nearest neighbor matching. @@ -25,7 +27,7 @@ def __init__( self, *, number_of_neighbors: int = 1, - ): + ) -> None: """Initializes the nearest neighbors class. Args: @@ -40,8 +42,8 @@ def match_controls( medrecord: MedRecord, control_group: Set[NodeIndex], treated_group: Set[NodeIndex], - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: MedRecordAttributeInputList = None, + one_hot_covariates: MedRecordAttributeInputList = None, ) -> Set[NodeIndex]: """Matches the controls based on the nearest neighbor algorithm. @@ -57,6 +59,10 @@ def match_controls( Returns: Set[NodeIndex]: Node Ids of the matched controls. """ + if one_hot_covariates is None: + one_hot_covariates = ["gender"] + if essential_covariates is None: + essential_covariates = ["gender", "age"] data_treated, data_control = self._preprocess_data( medrecord=medrecord, control_group=control_group, diff --git a/medmodels/treatment_effect/matching/propensity.py b/medmodels/treatment_effect/matching/propensity.py index 3ce6cfa6..4625ecbc 100644 --- a/medmodels/treatment_effect/matching/propensity.py +++ b/medmodels/treatment_effect/matching/propensity.py @@ -1,12 +1,10 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, Optional, Set import numpy as np import polars as pl -from medmodels import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) @@ -16,6 +14,10 @@ ) from medmodels.treatment_effect.matching.matching import Matching +if TYPE_CHECKING: + from medmodels import MedRecord + from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex + class PropensityMatching(Matching): """Class for the propensity score matching. @@ -37,7 +39,7 @@ def __init__( model: Model = "logit", number_of_neighbors: int = 1, hyperparam: Optional[Dict[str, Any]] = None, - ): + ) -> None: """Initializes the propensity score class. Args: @@ -60,8 +62,8 @@ def match_controls( medrecord: MedRecord, control_group: Set[NodeIndex], treated_group: Set[NodeIndex], - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: MedRecordAttributeInputList = None, + one_hot_covariates: MedRecordAttributeInputList = None, ) -> Set[NodeIndex]: """Matches the controls based on propensity score matching. @@ -78,6 +80,10 @@ def match_controls( Set[NodeIndex]: Node Ids of the matched controls. """ # Preprocess the data + if one_hot_covariates is None: + one_hot_covariates = ["gender"] + if essential_covariates is None: + essential_covariates = ["gender", "age"] data_treated, data_control = self._preprocess_data( medrecord=medrecord, treated_group=treated_group, diff --git a/medmodels/treatment_effect/matching/tests/test_classic_distance_models.py b/medmodels/treatment_effect/matching/tests/test_classic_distance_models.py index ba2e6dc7..9243a08a 100644 --- a/medmodels/treatment_effect/matching/tests/test_classic_distance_models.py +++ b/medmodels/treatment_effect/matching/tests/test_classic_distance_models.py @@ -2,6 +2,7 @@ import numpy as np import polars as pl +import pytest from medmodels.treatment_effect.matching.algorithms import ( classic_distance_models as cdm, @@ -9,7 +10,7 @@ class TestClassicDistanceModels(unittest.TestCase): - def test_nearest_neighbor(self): + def test_nearest_neighbor(self) -> None: ########################################### # 1D example c_set = pl.DataFrame({"a": [1, 5, 1, 3]}) @@ -17,7 +18,7 @@ def test_nearest_neighbor(self): expected_result = pl.DataFrame({"a": [1.0, 5.0]}) result = cdm.nearest_neighbor(t_set, c_set) - self.assertTrue(result.equals(expected_result)) + assert result.equals(expected_result) ########################################### # 3D example with covariates @@ -29,7 +30,7 @@ def test_nearest_neighbor(self): expected_result = pl.DataFrame([[1.0, 3.0, 5.0]], schema=cols, orient="row") result = cdm.nearest_neighbor(t_set, c_set, covariates=covs) - self.assertTrue(result.equals(expected_result)) + assert result.equals(expected_result) # 2 nearest neighbors expected_abs_2nn = pl.DataFrame( @@ -38,21 +39,19 @@ def test_nearest_neighbor(self): result_abs_2nn = cdm.nearest_neighbor( t_set, c_set, covariates=covs, number_of_neighbors=2 ) - self.assertTrue(result_abs_2nn.equals(expected_abs_2nn)) + assert result_abs_2nn.equals(expected_abs_2nn) - def test_nearest_neighbor_value_error(self): - # Test case for checking the ValueError when all control units have been matched + def test_nearest_neighbor_value_error(self) -> None: + """Checking the ValueError when all control units have been matched.""" c_set = pl.DataFrame({"a": [1, 2]}) t_set = pl.DataFrame({"a": [1, 2, 3]}) - with self.assertRaises(ValueError) as context: + with pytest.raises( + ValueError, + match="The treated set is too large for the given number of neighbors.", + ): cdm.nearest_neighbor(t_set, c_set, number_of_neighbors=2) - self.assertEqual( - str(context.exception), - "The treated set is too large for the given number of neighbors.", - ) - if __name__ == "__main__": run_test = unittest.TestLoader().loadTestsFromTestCase(TestClassicDistanceModels) diff --git a/medmodels/treatment_effect/matching/tests/test_metrics.py b/medmodels/treatment_effect/matching/tests/test_metrics.py index 80c16f34..81074f0f 100644 --- a/medmodels/treatment_effect/matching/tests/test_metrics.py +++ b/medmodels/treatment_effect/matching/tests/test_metrics.py @@ -6,21 +6,17 @@ class TestMetrics(unittest.TestCase): - def test_absolute_metric(self): - self.assertEqual(metrics.absolute_metric(np.array([-2]), np.array([-1])), 1) - self.assertEqual( - metrics.absolute_metric(np.array([2, -1]), np.array([1, 3])), 5 - ) + def test_absolute_metric(self) -> None: + assert metrics.absolute_metric(np.array([-2]), np.array([-1])) == 1 + assert metrics.absolute_metric(np.array([2, -1]), np.array([1, 3])) == 5 - def test_exact_metric(self): - self.assertEqual(metrics.exact_metric(np.array([-2]), np.array([-2])), 0) - self.assertEqual(metrics.exact_metric(np.array([-2]), np.array([-1])), np.inf) - self.assertEqual(metrics.exact_metric(np.array([2, -1]), np.array([2, -1])), 0) - self.assertEqual( - metrics.exact_metric(np.array([2, -1]), np.array([2, 1])), np.inf - ) + def test_exact_metric(self) -> None: + assert metrics.exact_metric(np.array([-2]), np.array([-2])) == 0 + assert metrics.exact_metric(np.array([-2]), np.array([-1])) == np.inf + assert metrics.exact_metric(np.array([2, -1]), np.array([2, -1])) == 0 + assert metrics.exact_metric(np.array([2, -1]), np.array([2, 1])) == np.inf - def test_mahalanobis_metric(self): + def test_mahalanobis_metric(self) -> None: data = np.array( [[64, 66, 68, 69, 73], [580, 570, 590, 660, 600], [29, 33, 37, 46, 55]] ) diff --git a/medmodels/treatment_effect/matching/tests/test_propensity_score.py b/medmodels/treatment_effect/matching/tests/test_propensity_score.py index e505d6ab..af6a6400 100644 --- a/medmodels/treatment_effect/matching/tests/test_propensity_score.py +++ b/medmodels/treatment_effect/matching/tests/test_propensity_score.py @@ -8,7 +8,7 @@ class TestPropensityScore(unittest.TestCase): - def test_calculate_propensity(self): + def test_calculate_propensity(self) -> None: x, y = load_iris(return_X_y=True) # Set random state by each propensity estimator: @@ -48,7 +48,7 @@ def test_calculate_propensity(self): self.assertAlmostEqual(result_1[0], 0, places=2) self.assertAlmostEqual(result_2[0], 0, places=2) - def test_run_propensity_score(self): + def test_run_propensity_score(self) -> None: # Set random state by each propensity estimator: hyperparam = {"random_state": 1} hyperparam_logit = {"random_state": 1, "max_iter": 200} @@ -63,21 +63,21 @@ def test_run_propensity_score(self): result_logit = ps.run_propensity_score( treated_set, control_set, hyperparam=hyperparam_logit ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # dec_tree metric expected_logit = pl.DataFrame({"a": [1.0, 1.0]}) result_logit = ps.run_propensity_score( treated_set, control_set, model="dec_tree", hyperparam=hyperparam ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # forest model expected_logit = pl.DataFrame({"a": [1.0, 1.0]}) result_logit = ps.run_propensity_score( treated_set, control_set, model="forest", hyperparam=hyperparam ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) ########################################### # 3D example with covariates @@ -92,7 +92,7 @@ def test_run_propensity_score(self): result_logit = ps.run_propensity_score( treated_set, control_set, covariates=covs, hyperparam=hyperparam_logit ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # dec_tree model expected_logit = pl.DataFrame({"a": [1.0], "b": [3.0], "c": [5.0]}) @@ -103,7 +103,7 @@ def test_run_propensity_score(self): covariates=covs, hyperparam=hyperparam, ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # forest model expected_logit = pl.DataFrame({"a": [1.0], "b": [3.0], "c": [5.0]}) @@ -114,7 +114,7 @@ def test_run_propensity_score(self): covariates=covs, hyperparam=hyperparam, ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # using 2 nearest neighbors expected_logit = pl.DataFrame( @@ -125,7 +125,7 @@ def test_run_propensity_score(self): control_set, number_of_neighbors=2, ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) if __name__ == "__main__": diff --git a/medmodels/treatment_effect/report.py b/medmodels/treatment_effect/report.py index 26bcdad5..f9c50ce0 100644 --- a/medmodels/treatment_effect/report.py +++ b/medmodels/treatment_effect/report.py @@ -2,10 +2,9 @@ from typing import TYPE_CHECKING, Literal, TypedDict -from medmodels.medrecord.medrecord import MedRecord -from medmodels.medrecord.types import MedRecordAttribute - if TYPE_CHECKING: + from medmodels.medrecord.medrecord import MedRecord + from medmodels.medrecord.types import MedRecordAttribute from medmodels.treatment_effect.treatment_effect import TreatmentEffect diff --git a/medmodels/treatment_effect/temporal_analysis.py b/medmodels/treatment_effect/temporal_analysis.py index f87d79d8..f8787e55 100644 --- a/medmodels/treatment_effect/temporal_analysis.py +++ b/medmodels/treatment_effect/temporal_analysis.py @@ -77,7 +77,8 @@ def find_reference_edge( edge_values = medrecord.edge[edges].values() if not all(time_attribute in edge_attribute for edge_attribute in edge_values): - raise ValueError("Time attribute not found in the edge attributes") + msg = "Time attribute not found in the edge attributes" + raise ValueError(msg) for edge in edges: edge_time = pd.to_datetime(str(medrecord.edge[edge][time_attribute])) @@ -89,7 +90,8 @@ def find_reference_edge( reference_time = edge_time if reference_edge is None: - raise ValueError(f"No edge found for node {node_index} in this MedRecord") + msg = f"No edge found for node {node_index} in this MedRecord" + raise ValueError(msg) return reference_edge @@ -171,7 +173,8 @@ def find_node_in_time_window( for edge in edges: edge_attributes = medrecord.edge[edge] if time_attribute not in edge_attributes: - raise ValueError("Time attribute not found in the edge attributes") + msg = "Time attribute not found in the edge attributes" + raise ValueError(msg) event_time = pd.to_datetime(str(edge_attributes[time_attribute])) time_difference = event_time - reference_time diff --git a/medmodels/treatment_effect/tests/test_continuous_estimators.py b/medmodels/treatment_effect/tests/test_continuous_estimators.py index f612d058..952c2d41 100644 --- a/medmodels/treatment_effect/tests/test_continuous_estimators.py +++ b/medmodels/treatment_effect/tests/test_continuous_estimators.py @@ -1,9 +1,10 @@ """Tests for the TreatmentEffect class in the treatment_effect module.""" import unittest -from typing import List +from typing import List, Optional import pandas as pd +import pytest from medmodels import MedRecord from medmodels.medrecord.types import NodeIndex @@ -37,8 +38,7 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: } ) - patients = patients.loc[patients["index"].isin(patient_list)] - return patients + return patients.loc[patients["index"].isin(patient_list)] def create_diagnoses() -> pd.DataFrame: @@ -47,13 +47,12 @@ def create_diagnoses() -> pd.DataFrame: Returns: pd.DataFrame: A diagnoses dataframe. """ - diagnoses = pd.DataFrame( + return pd.DataFrame( { "index": ["D1"], "name": ["Stroke"], } ) - return diagnoses def create_prescriptions() -> pd.DataFrame: @@ -62,13 +61,12 @@ def create_prescriptions() -> pd.DataFrame: Returns: pd.DataFrame: A prescriptions dataframe. """ - prescriptions = pd.DataFrame( + return pd.DataFrame( { "index": ["M1", "M2"], "name": ["Rivaroxaban", "Warfarin"], } ) - return prescriptions def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -108,8 +106,7 @@ def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -162,28 +159,19 @@ def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_medrecord( - patient_list: List[NodeIndex] = [ - "P1", - "P2", - "P3", - "P4", - "P5", - "P6", - "P7", - "P8", - "P9", - ], + patient_list: Optional[List[NodeIndex]] = None, ) -> MedRecord: """Creates a MedRecord object. Returns: MedRecord: A MedRecord object. """ + if patient_list is None: + patient_list = ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"] patients = create_patients(patient_list=patient_list) diagnoses = create_diagnoses() prescriptions = create_prescriptions() @@ -213,12 +201,12 @@ def create_medrecord( class TestContinuousEstimators(unittest.TestCase): """Class to test the continuous estimators.""" - def setUp(self): + def setUp(self) -> None: self.medrecord = create_medrecord() self.outcome_group = "Stroke" self.time_attribute = "time" - def test_average_treatment_effect(self): + def test_average_treatment_effect(self) -> None: ate_result = average_treatment_effect( self.medrecord, treatment_outcome_true_set=set({"P2", "P3"}), @@ -241,8 +229,8 @@ def test_average_treatment_effect(self): ) self.assertAlmostEqual(-0.15, ate_result) - def test_invalid_treatment_effect(self): - with self.assertRaisesRegex(ValueError, "Outcome variable must be numeric"): + def test_invalid_treatment_effect(self) -> None: + with pytest.raises(ValueError, match="Outcome variable must be numeric"): average_treatment_effect( self.medrecord, treatment_outcome_true_set=set({"P2", "P3"}), @@ -253,7 +241,7 @@ def test_invalid_treatment_effect(self): time_attribute=self.time_attribute, ) - def test_cohens_d(self): + def test_cohens_d(self) -> None: cohens_d_result = cohens_d( self.medrecord, treatment_outcome_true_set=set({"P2", "P3"}), @@ -288,8 +276,8 @@ def test_cohens_d(self): ) self.assertAlmostEqual(0, cohens_d_corrected) - def test_invalid_cohens_D(self): - with self.assertRaisesRegex(ValueError, "Outcome variable must be numeric"): + def test_invalid_cohens_D(self) -> None: + with pytest.raises(ValueError, match="Outcome variable must be numeric"): cohens_d( self.medrecord, treatment_outcome_true_set=set({"P2", "P3"}), diff --git a/medmodels/treatment_effect/tests/test_temporal_analysis.py b/medmodels/treatment_effect/tests/test_temporal_analysis.py index 41121f5d..35ac6b8c 100644 --- a/medmodels/treatment_effect/tests/test_temporal_analysis.py +++ b/medmodels/treatment_effect/tests/test_temporal_analysis.py @@ -1,7 +1,8 @@ import unittest -from typing import List +from typing import List, Optional import pandas as pd +import pytest from medmodels.medrecord.medrecord import MedRecord from medmodels.medrecord.types import NodeIndex @@ -29,8 +30,7 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: } ) - patients = patients.loc[patients["index"].isin(patient_list)] - return patients + return patients.loc[patients["index"].isin(patient_list)] def create_diagnoses() -> pd.DataFrame: @@ -39,13 +39,12 @@ def create_diagnoses() -> pd.DataFrame: Returns: pd.DataFrame: A diagnoses dataframe. """ - diagnoses = pd.DataFrame( + return pd.DataFrame( { "index": ["D1"], "name": ["Stroke"], } ) - return diagnoses def create_prescriptions() -> pd.DataFrame: @@ -54,13 +53,12 @@ def create_prescriptions() -> pd.DataFrame: Returns: pd.DataFrame: A prescriptions dataframe. """ - prescriptions = pd.DataFrame( + return pd.DataFrame( { "index": ["M1", "M2"], "name": ["Rivaroxaban", "Warfarin"], } ) - return prescriptions def create_edges(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -94,22 +92,19 @@ def create_edges(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_medrecord( - patient_list: List[NodeIndex] = [ - "P1", - "P2", - "P3", - ], + patient_list: Optional[List[NodeIndex]] = None, ) -> MedRecord: """Creates a MedRecord object. Returns: MedRecord: A MedRecord object. """ + if patient_list is None: + patient_list = ["P1", "P2", "P3"] patients = create_patients(patient_list=patient_list) diagnoses = create_diagnoses() prescriptions = create_prescriptions() @@ -137,17 +132,17 @@ def create_medrecord( class TestTreatmentEffect(unittest.TestCase): """""" - def setUp(self): + def setUp(self) -> None: self.medrecord = create_medrecord() - def test_find_reference_time(self): + def test_find_reference_time(self) -> None: edge = find_reference_edge( self.medrecord, node_index="P1", reference="last", connected_group="Rivaroxaban", ) - self.assertEqual(0, edge) + assert edge == 0 # adding medication time self.medrecord.add_edges(("M1", "P1", {"time": "2000-01-15"})) @@ -158,7 +153,7 @@ def test_find_reference_time(self): reference="last", connected_group="Rivaroxaban", ) - self.assertEqual(5, edge) + assert edge == 5 edge = find_reference_edge( self.medrecord, @@ -166,11 +161,11 @@ def test_find_reference_time(self): reference="first", connected_group="Rivaroxaban", ) - self.assertEqual(0, edge) + assert edge == 0 - def test_invalid_find_reference_time(self): - with self.assertRaisesRegex( - ValueError, "Time attribute not found in the edge attributes" + def test_invalid_find_reference_time(self) -> None: + with pytest.raises( + ValueError, match="Time attribute not found in the edge attributes" ): find_reference_edge( self.medrecord, @@ -181,8 +176,8 @@ def test_invalid_find_reference_time(self): ) node_index = "P2" - with self.assertRaisesRegex( - ValueError, f"No edge found for node {node_index} in this MedRecord" + with pytest.raises( + ValueError, match=f"No edge found for node {node_index} in this MedRecord" ): find_reference_edge( self.medrecord, @@ -192,7 +187,7 @@ def test_invalid_find_reference_time(self): time_attribute="time", ) - def test_node_in_time_window(self): + def test_node_in_time_window(self) -> None: # check if patient has outcome a year after treatment node_found = find_node_in_time_window( self.medrecord, @@ -204,7 +199,7 @@ def test_node_in_time_window(self): reference="last", time_attribute="time", ) - self.assertTrue(node_found) + assert node_found # check if patient has outcome 30 days after treatment node_found2 = find_node_in_time_window( @@ -217,11 +212,11 @@ def test_node_in_time_window(self): reference="last", time_attribute="time", ) - self.assertFalse(node_found2) + assert not node_found2 - def test_invalid_node_in_time_window(self): - with self.assertRaisesRegex( - ValueError, "Time attribute not found in the edge attributes" + def test_invalid_node_in_time_window(self) -> None: + with pytest.raises( + ValueError, match="Time attribute not found in the edge attributes" ): find_node_in_time_window( self.medrecord, diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index aca74858..287a409f 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -1,9 +1,10 @@ """Tests for the TreatmentEffect class in the treatment_effect module.""" import unittest -from typing import List +from typing import List, Optional import pandas as pd +import pytest from medmodels import MedRecord from medmodels.medrecord import edge, node @@ -36,8 +37,7 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: } ) - patients = patients.loc[patients["index"].isin(patient_list)] - return patients + return patients.loc[patients["index"].isin(patient_list)] def create_diagnoses() -> pd.DataFrame: @@ -46,13 +46,12 @@ def create_diagnoses() -> pd.DataFrame: Returns: pd.DataFrame: A diagnoses dataframe. """ - diagnoses = pd.DataFrame( + return pd.DataFrame( { "index": ["D1"], "name": ["Stroke"], } ) - return diagnoses def create_prescriptions() -> pd.DataFrame: @@ -61,13 +60,12 @@ def create_prescriptions() -> pd.DataFrame: Returns: pd.DataFrame: A prescriptions dataframe. """ - prescriptions = pd.DataFrame( + return pd.DataFrame( { "index": ["M1", "M2"], "name": ["Rivaroxaban", "Warfarin"], } ) - return prescriptions def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -107,8 +105,7 @@ def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -153,28 +150,19 @@ def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_medrecord( - patient_list: List[NodeIndex] = [ - "P1", - "P2", - "P3", - "P4", - "P5", - "P6", - "P7", - "P8", - "P9", - ], + patient_list: Optional[List[NodeIndex]] = None, ) -> MedRecord: """Creates a MedRecord object. Returns: MedRecord: A MedRecord object. """ + if patient_list is None: + patient_list = ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"] patients = create_patients(patient_list=patient_list) diagnoses = create_diagnoses() prescriptions = create_prescriptions() @@ -205,79 +193,65 @@ def assert_treatment_effects_equal( test_case: unittest.TestCase, treatment_effect1: TreatmentEffect, treatment_effect2: TreatmentEffect, -): - test_case.assertEqual( - treatment_effect1._treatments_group, treatment_effect2._treatments_group +) -> None: + assert treatment_effect1._treatments_group == treatment_effect2._treatments_group + assert treatment_effect1._outcomes_group == treatment_effect2._outcomes_group + assert treatment_effect1._patients_group == treatment_effect2._patients_group + assert treatment_effect1._time_attribute == treatment_effect2._time_attribute + assert ( + treatment_effect1._washout_period_days == treatment_effect2._washout_period_days ) - test_case.assertEqual( - treatment_effect1._outcomes_group, treatment_effect2._outcomes_group + assert ( + treatment_effect1._washout_period_reference + == treatment_effect2._washout_period_reference ) - test_case.assertEqual( - treatment_effect1._patients_group, treatment_effect2._patients_group + assert treatment_effect1._grace_period_days == treatment_effect2._grace_period_days + assert ( + treatment_effect1._grace_period_reference + == treatment_effect2._grace_period_reference ) - test_case.assertEqual( - treatment_effect1._time_attribute, treatment_effect2._time_attribute + assert ( + treatment_effect1._follow_up_period_days + == treatment_effect2._follow_up_period_days ) - test_case.assertEqual( - treatment_effect1._washout_period_days, treatment_effect2._washout_period_days + assert ( + treatment_effect1._follow_up_period_reference + == treatment_effect2._follow_up_period_reference ) - test_case.assertEqual( - treatment_effect1._washout_period_reference, - treatment_effect2._washout_period_reference, + assert ( + treatment_effect1._outcome_before_treatment_days + == treatment_effect2._outcome_before_treatment_days ) - test_case.assertEqual( - treatment_effect1._grace_period_days, treatment_effect2._grace_period_days + assert ( + treatment_effect1._filter_controls_operation + == treatment_effect2._filter_controls_operation ) - test_case.assertEqual( - treatment_effect1._grace_period_reference, - treatment_effect2._grace_period_reference, + assert treatment_effect1._matching_method == treatment_effect2._matching_method + assert ( + treatment_effect1._matching_essential_covariates + == treatment_effect2._matching_essential_covariates ) - test_case.assertEqual( - treatment_effect1._follow_up_period_days, - treatment_effect2._follow_up_period_days, + assert ( + treatment_effect1._matching_one_hot_covariates + == treatment_effect2._matching_one_hot_covariates ) - test_case.assertEqual( - treatment_effect1._follow_up_period_reference, - treatment_effect2._follow_up_period_reference, + assert treatment_effect1._matching_model == treatment_effect2._matching_model + assert ( + treatment_effect1._matching_number_of_neighbors + == treatment_effect2._matching_number_of_neighbors ) - test_case.assertEqual( - treatment_effect1._outcome_before_treatment_days, - treatment_effect2._outcome_before_treatment_days, - ) - test_case.assertEqual( - treatment_effect1._filter_controls_operation, - treatment_effect2._filter_controls_operation, - ) - test_case.assertEqual( - treatment_effect1._matching_method, treatment_effect2._matching_method - ) - test_case.assertEqual( - treatment_effect1._matching_essential_covariates, - treatment_effect2._matching_essential_covariates, - ) - test_case.assertEqual( - treatment_effect1._matching_one_hot_covariates, - treatment_effect2._matching_one_hot_covariates, - ) - test_case.assertEqual( - treatment_effect1._matching_model, treatment_effect2._matching_model - ) - test_case.assertEqual( - treatment_effect1._matching_number_of_neighbors, - treatment_effect2._matching_number_of_neighbors, - ) - test_case.assertEqual( - treatment_effect1._matching_hyperparam, treatment_effect2._matching_hyperparam + assert ( + treatment_effect1._matching_hyperparam == treatment_effect2._matching_hyperparam ) class TestTreatmentEffect(unittest.TestCase): """Class to test the TreatmentEffect class in the treatment_effect module.""" - def setUp(self): + def setUp(self) -> None: self.medrecord = create_medrecord() - def test_init(self): + def test_init(self) -> None: # Initialize TreatmentEffect object tee = TreatmentEffect( treatment="Rivaroxaban", @@ -293,7 +267,7 @@ def test_init(self): assert_treatment_effects_equal(self, tee, tee_builder) - def test_default_properties(self): + def test_default_properties(self) -> None: tee = TreatmentEffect( treatment="Rivaroxaban", outcome="Stroke", @@ -313,7 +287,7 @@ def test_default_properties(self): assert_treatment_effects_equal(self, tee, tee_builder) - def test_check_medrecord(self): + def test_check_medrecord(self) -> None: tee = ( TreatmentEffect.builder() .with_outcome("Stroke") @@ -321,8 +295,8 @@ def test_check_medrecord(self): .build() ) - with self.assertRaisesRegex( - ValueError, "Treatment group not found in the MedRecord" + with pytest.raises( + ValueError, match="Treatment group not found in the MedRecord" ): tee.estimate._check_medrecord(medrecord=self.medrecord) @@ -333,8 +307,8 @@ def test_check_medrecord(self): .build() ) - with self.assertRaisesRegex( - ValueError, "Outcome group not found in the MedRecord" + with pytest.raises( + ValueError, match="Outcome group not found in the MedRecord" ): tee2.estimate._check_medrecord(medrecord=self.medrecord) @@ -347,12 +321,13 @@ def test_check_medrecord(self): .build() ) - with self.assertRaisesRegex( - ValueError, f"Patient group {patient_group} not found in the MedRecord" + with pytest.raises( + ValueError, + match=f"Patient group {patient_group} not found in the MedRecord", ): tee3.estimate._check_medrecord(medrecord=self.medrecord) - def test_find_treated_patients(self): + def test_find_treated_patients(self) -> None: tee = ( TreatmentEffect.builder() .with_outcome("Stroke") @@ -361,7 +336,7 @@ def test_find_treated_patients(self): ) treated_group = tee._find_treated_patients(self.medrecord) - self.assertEqual(treated_group, set({"P2", "P3", "P6"})) + assert treated_group == set({"P2", "P3", "P6"}) # no treatment_group patients = set(self.medrecord.nodes_in_group("patients")) @@ -372,7 +347,7 @@ def test_find_treated_patients(self): ): tee.estimate._compute_subject_counts(medrecord=medrecord2) - def test_find_groups(self): + def test_find_groups(self) -> None: tee = ( TreatmentEffect.builder() .with_outcome("Stroke") @@ -391,7 +366,7 @@ def test_find_groups(self): self.assertEqual(control_outcome_true, set({"P1", "P4", "P7"})) self.assertEqual(control_outcome_false, set({"P5", "P8", "P9"})) - def test_compute_subject_counts(self): + def test_compute_subject_counts(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -400,7 +375,7 @@ def test_compute_subject_counts(self): ) counts = tee.estimate._compute_subject_counts(self.medrecord) - self.assertEqual(counts, (2, 1, 3, 3)) + assert counts == (2, 1, 3, 3) def test_invalid_compute_subject_counts(self): tee = ( @@ -448,7 +423,7 @@ def test_invalid_compute_subject_counts(self): ): tee.estimate._compute_subject_counts(medrecord=medrecord4) - def test_subject_counts(self): + def test_subject_counts(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -495,7 +470,7 @@ def test_subjects_indices(self): subjects_tee["treated_outcome_true"], ) - def test_follow_up_period(self): + def test_follow_up_period(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -504,13 +479,13 @@ def test_follow_up_period(self): .build() ) - self.assertEqual(tee._follow_up_period_days, 30) + assert tee._follow_up_period_days == 30 counts_tee = tee.estimate._compute_subject_counts(self.medrecord) self.assertEqual((1, 2, 3, 3), counts_tee) - def test_grace_period(self): + def test_grace_period(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -519,13 +494,13 @@ def test_grace_period(self): .build() ) - self.assertEqual(tee._grace_period_days, 10) + assert tee._grace_period_days == 10 counts_tee = tee.estimate._compute_subject_counts(self.medrecord) self.assertEqual((1, 2, 3, 3), counts_tee) - def test_washout_period(self): + def test_washout_period(self) -> None: washout_dict = {"Warfarin": 30} tee = ( @@ -543,8 +518,8 @@ def test_washout_period(self): self.medrecord, treated_group ) - self.assertEqual(treated_group, set({"P3", "P6"})) - self.assertEqual(washout_nodes, set({"P2"})) + assert treated_group == set({"P3", "P6"}) + assert washout_nodes == set({"P2"}) # smaller washout period washout_dict2 = {"Warfarin": 10} @@ -564,10 +539,10 @@ def test_washout_period(self): self.medrecord, treated_group ) - self.assertEqual(treated_group, set({"P2", "P3", "P6"})) - self.assertEqual(washout_nodes, set({})) + assert treated_group == set({"P2", "P3", "P6"}) + assert washout_nodes == set({}) - def test_outcome_before_treatment(self): + def test_outcome_before_treatment(self) -> None: # case 1 find outcomes for default tee tee = ( TreatmentEffect.builder() @@ -592,7 +567,7 @@ def test_outcome_before_treatment(self): .build() ) - self.assertEqual(tee2._outcome_before_treatment_days, 30) + assert tee2._outcome_before_treatment_days == 30 treated_group = tee2._find_treated_patients(self.medrecord) treated_group, treatment_outcome_true, outcome_before_treatment_nodes = ( @@ -614,12 +589,12 @@ def test_outcome_before_treatment(self): .build() ) - with self.assertRaisesRegex( - ValueError, "No outcomes found in the MedRecord for group " + with pytest.raises( + ValueError, match="No outcomes found in the MedRecord for group " ): tee3._find_outcomes(medrecord=self.medrecord, treated_group=treated_group) - def test_filter_controls(self): + def test_filter_controls(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -647,7 +622,7 @@ def test_filter_controls(self): self.assertEqual(counts_tee2, (2, 1, 1, 1)) - def test_nearest_neighbors(self): + def test_nearest_neighbors(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -664,7 +639,7 @@ def test_nearest_neighbors(self): self.assertIn("P5", subjects["control_outcome_false"]) self.assertIn("P8", subjects["control_outcome_false"]) - def test_propensity_matching(self): + def test_propensity_matching(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -679,7 +654,7 @@ def test_propensity_matching(self): self.assertIn("P5", subjects["control_outcome_false"]) self.assertIn("P1", subjects["control_outcome_true"]) - def test_find_controls(self): + def test_find_controls(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -698,8 +673,8 @@ def test_find_controls(self): self.assertEqual(control_outcome_true, {"P1", "P4", "P7"}) self.assertEqual(control_outcome_false, {"P5", "P8", "P9"}) - with self.assertRaisesRegex( - ValueError, "No patients found for control groups in this MedRecord." + with pytest.raises( + ValueError, match="No patients found for control groups in this MedRecord." ): tee._find_controls( self.medrecord, @@ -717,8 +692,8 @@ def test_find_controls(self): self.medrecord.add_group("Headache") - with self.assertRaisesRegex( - ValueError, "No outcomes found in the MedRecord for group." + with pytest.raises( + ValueError, match="No outcomes found in the MedRecord for group." ): tee2._find_controls( self.medrecord, @@ -726,7 +701,7 @@ def test_find_controls(self): treated_group=patients.intersection(treated_group), ) - def test_metrics(self): + def test_metrics(self) -> None: """Test the metrics of the TreatmentEffect class.""" tee = ( TreatmentEffect.builder() @@ -745,7 +720,7 @@ def test_metrics(self): self.assertAlmostEqual(tee.estimate.hazard_ratio(self.medrecord), 4 / 3) self.assertAlmostEqual(tee.estimate.number_needed_to_treat(self.medrecord), -6) - def test_full_report(self): + def test_full_report(self) -> None: """Test the full reporting of the TreatmentEffect class.""" tee = ( TreatmentEffect.builder() @@ -771,7 +746,7 @@ def test_full_report(self): } self.assertDictEqual(report_test, full_report) - def test_continuous_estimators_report(self): + def test_continuous_estimators_report(self) -> None: """Test the continuous report of the TreatmentEffect class.""" tee = ( TreatmentEffect.builder() diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index 94a3980e..133174ae 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -11,24 +11,26 @@ from __future__ import annotations import logging -from typing import Any, Dict, Literal, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set, Tuple -from medmodels import MedRecord from medmodels.medrecord import node -from medmodels.medrecord.querying import NodeOperation -from medmodels.medrecord.types import ( - Group, - MedRecordAttribute, - MedRecordAttributeInputList, - NodeIndex, -) from medmodels.treatment_effect.builder import TreatmentEffectBuilder from medmodels.treatment_effect.estimate import Estimate -from medmodels.treatment_effect.matching.algorithms.propensity_score import Model -from medmodels.treatment_effect.matching.matching import MatchingMethod from medmodels.treatment_effect.report import Report from medmodels.treatment_effect.temporal_analysis import find_node_in_time_window +if TYPE_CHECKING: + from medmodels import MedRecord + from medmodels.medrecord.querying import NodeOperation + from medmodels.medrecord.types import ( + Group, + MedRecordAttribute, + MedRecordAttributeInputList, + NodeIndex, + ) + from medmodels.treatment_effect.matching.algorithms.propensity_score import Model + from medmodels.treatment_effect.matching.matching import MatchingMethod + class TreatmentEffect: """This class facilitates the analysis of treatment effects over time and across different patient groups.""" @@ -85,7 +87,7 @@ def _set_configuration( outcome: Group, patients_group: Group = "patients", time_attribute: MedRecordAttribute = "time", - washout_period_days: Dict[str, int] = dict(), + washout_period_days: Optional[Dict[str, int]] = None, washout_period_reference: Literal["first", "last"] = "first", grace_period_days: int = 0, grace_period_reference: Literal["first", "last"] = "last", @@ -94,8 +96,8 @@ def _set_configuration( outcome_before_treatment_days: Optional[int] = None, filter_controls_operation: Optional[NodeOperation] = None, matching_method: Optional[MatchingMethod] = None, - matching_essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - matching_one_hot_covariates: MedRecordAttributeInputList = ["gender"], + matching_essential_covariates: MedRecordAttributeInputList = None, + matching_one_hot_covariates: MedRecordAttributeInputList = None, matching_model: Model = "logit", matching_number_of_neighbors: int = 1, matching_hyperparam: Optional[Dict[str, Any]] = None, @@ -145,6 +147,12 @@ def _set_configuration( matching_hyperparam (Optional[Dict[str, Any]], optional): The hyperparameters for the matching model. Defaults to None. """ + if matching_one_hot_covariates is None: + matching_one_hot_covariates = ["gender"] + if washout_period_days is None: + washout_period_days = {} + if matching_essential_covariates is None: + matching_essential_covariates = ["gender", "age"] treatment_effect._patients_group = patients_group treatment_effect._time_attribute = time_attribute @@ -247,9 +255,8 @@ def _find_treated_patients(self, medrecord: MedRecord) -> Set[NodeIndex]: ) ) if not treated_group: - raise ValueError( - "No patients found for the treatment groups in this MedRecord." - ) + msg = "No patients found for the treatment groups in this MedRecord." + raise ValueError(msg) return treated_group @@ -284,9 +291,8 @@ def _find_outcomes( # Find nodes with the outcomes outcomes = medrecord.nodes_in_group(self._outcomes_group) if not outcomes: - raise ValueError( - f"No outcomes found in the MedRecord for group {self._outcomes_group}" - ) + msg = f"No outcomes found in the MedRecord for group {self._outcomes_group}" + raise ValueError(msg) for outcome in outcomes: nodes_to_check = set( @@ -398,7 +404,7 @@ def _find_controls( medrecord: MedRecord, control_group: Set[NodeIndex], treated_group: Set[NodeIndex], - rejected_nodes: Set[NodeIndex] = set(), + rejected_nodes: Optional[Set[NodeIndex]] = None, filter_controls_operation: Optional[NodeOperation] = None, ) -> Tuple[Set[NodeIndex], Set[NodeIndex]]: """Identifies control groups among patients who did not undergo the specified treatments. @@ -436,6 +442,8 @@ def _find_controls( outcome group. """ # Apply the filter to the control group if specified + if rejected_nodes is None: + rejected_nodes = set() if filter_controls_operation: control_group = ( set(medrecord.select_nodes(filter_controls_operation)) & control_group @@ -443,15 +451,15 @@ def _find_controls( control_group = control_group - treated_group - rejected_nodes if len(control_group) == 0: - raise ValueError("No patients found for control groups in this MedRecord.") + msg = "No patients found for control groups in this MedRecord." + raise ValueError(msg) control_outcome_true = set() control_outcome_false = set() outcomes = medrecord.nodes_in_group(self._outcomes_group) if not outcomes: - raise ValueError( - f"No outcomes found in the MedRecord for group {self._outcomes_group}" - ) + msg = f"No outcomes found in the MedRecord for group {self._outcomes_group}" + raise ValueError(msg) # Finding the patients that had the outcome in the control group for outcome in outcomes: diff --git a/pyproject.toml b/pyproject.toml index 37a59853..a625878d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,70 @@ exclude = [ ] line-length = 88 +[tool.ruff.lint] +select = [ + "E", # PEP 8 codestyle errors + "F", # pyflakes + "I", # isort + "N", # PEP 8 naming + "DOC", # Pydoc Linting (preview); complementary to "D" + "D", # Pydoc Style; PEP 257 + "FA", # future annotations linting; PEP 563 + "W", # pycodestyle warnings; PEP 8 + "SIM", # flake8 simplify; simplify code + "ANN", # flake8 function annotations; PEP 3107 + "B", # bugbear extension for flake8; opinionated, not based on any PEP + "C4", # list/set/dict comprehensions + "T10", # Check for debugging leftovers: pdb;idbp imports and set traces + "EM", # error messages + "LOG", # logging module usage linting + "G", # logging format strings + "T20", # print statements + "PYI", # lint stub files .pyi + "PT", # pytest linting + "RET", # return values + "TCH", # type checking + "PTH", # pathlib usage + "PERF", # performance linting + "FURB", # modern python code patterns + "RUF", # ruff specific rules + "FBT", # no bool as function param + "TD", # todo linting + "C90", # mccabe complexity +] +preview = true +ignore = [ + "E501", # Line length managed by formatter + # indentation linters conflicting with formatter: + "W191", + "E111", + "E114", + "E117", + "D206", + # quotation linters conflicting with formatter: + "D300", + "Q000", + "Q001", + "Q002", + "Q003", + # comma linters conflicting with formatter: + "COM812", + "COM819", + # string concatenation linters conflicting with formatter: + "ISC001", + "ISC002", +] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.pycodestyle] +max-doc-length = 88 + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 88 + [tool.pyright] typeCheckingMode = "strict" reportPrivateUsage = false