diff --git a/medmodels/medrecord/_overview.py b/medmodels/medrecord/_overview.py index e334d25..f487646 100644 --- a/medmodels/medrecord/_overview.py +++ b/medmodels/medrecord/_overview.py @@ -232,10 +232,12 @@ def prettify_table( "-" * (sum(lengths) + len(lengths)), ] - table.extend([ - "".join(f"{row[x]: <{lengths[x]}} " for x in range(len(lengths))) - for row in rows - ]) + table.extend( + [ + "".join(f"{row[x]: <{lengths[x]}} " for x in range(len(lengths))) + for row in rows + ] + ) table.append("-" * (sum(lengths) + len(lengths))) diff --git a/medmodels/medrecord/schema.py b/medmodels/medrecord/schema.py index 5acd218..7d704e9 100644 --- a/medmodels/medrecord/schema.py +++ b/medmodels/medrecord/schema.py @@ -328,10 +328,12 @@ def _convert_node( else None, ) - return AttributesSchema({ - x: _convert_node(self._group_schema.nodes[x]) - for x in self._group_schema.nodes - }) + return AttributesSchema( + { + x: _convert_node(self._group_schema.nodes[x]) + for x in self._group_schema.nodes + } + ) @property def edges(self) -> AttributesSchema: @@ -352,10 +354,12 @@ def _convert_edge( else None, ) - return AttributesSchema({ - x: _convert_edge(self._group_schema.edges[x]) - for x in self._group_schema.edges - }) + return AttributesSchema( + { + x: _convert_edge(self._group_schema.edges[x]) + for x in self._group_schema.edges + } + ) @property def strict(self) -> Optional[bool]: diff --git a/medmodels/medrecord/tests/test_medrecord.py b/medmodels/medrecord/tests/test_medrecord.py index d488647..5db2b22 100644 --- a/medmodels/medrecord/tests/test_medrecord.py +++ b/medmodels/medrecord/tests/test_medrecord.py @@ -31,33 +31,41 @@ def create_edges() -> List[Tuple[NodeIndex, NodeIndex, Attributes]]: def create_pandas_nodes_dataframe() -> pd.DataFrame: - return pd.DataFrame({ - "index": ["0", "1"], - "attribute": [1, 2], - }) + return pd.DataFrame( + { + "index": ["0", "1"], + "attribute": [1, 2], + } + ) def create_second_pandas_nodes_dataframe() -> pd.DataFrame: - return pd.DataFrame({ - "index": ["2", "3"], - "attribute": [2, 3], - }) + return pd.DataFrame( + { + "index": ["2", "3"], + "attribute": [2, 3], + } + ) def create_pandas_edges_dataframe() -> pd.DataFrame: - return pd.DataFrame({ - "source": ["0", "1"], - "target": ["1", "0"], - "attribute": [1, 2], - }) + return pd.DataFrame( + { + "source": ["0", "1"], + "target": ["1", "0"], + "attribute": [1, 2], + } + ) def create_second_pandas_edges_dataframe() -> pd.DataFrame: - return pd.DataFrame({ - "source": ["0", "1"], - "target": ["1", "0"], - "attribute": [2, 3], - }) + return pd.DataFrame( + { + "source": ["0", "1"], + "target": ["1", "0"], + "attribute": [2, 3], + } + ) def create_medrecord() -> MedRecord: @@ -277,11 +285,13 @@ def test_schema(self) -> None: medrecord.add_edges_to_group("group", edge_index) - edge_index = medrecord.add_edges(( - "0", - "1", - {"attribute": 1, "attribute2": "1"}, - )) + edge_index = medrecord.add_edges( + ( + "0", + "1", + {"attribute": 1, "attribute2": "1"}, + ) + ) with self.assertRaises(ValueError): medrecord.add_edges_to_group("group", edge_index) @@ -602,10 +612,12 @@ def test_add_nodes(self) -> None: assert medrecord.node_count() == 0 - medrecord.add_nodes([ - (create_pandas_nodes_dataframe(), "index"), - (create_second_pandas_nodes_dataframe(), "index"), - ]) + medrecord.add_nodes( + [ + (create_pandas_nodes_dataframe(), "index"), + (create_second_pandas_nodes_dataframe(), "index"), + ] + ) assert medrecord.node_count() == 4 @@ -649,10 +661,12 @@ def test_add_nodes(self) -> None: assert medrecord.node_count() == 0 - medrecord.add_nodes([ - (nodes, "index"), - (second_nodes, "index"), - ]) + medrecord.add_nodes( + [ + (nodes, "index"), + (second_nodes, "index"), + ] + ) assert medrecord.node_count() == 4 @@ -908,10 +922,12 @@ def test_add_edges(self) -> None: assert medrecord.edge_count() == 0 - medrecord.add_edges([ - (create_pandas_edges_dataframe(), "source", "target"), - (create_second_pandas_edges_dataframe(), "source", "target"), - ]) + medrecord.add_edges( + [ + (create_pandas_edges_dataframe(), "source", "target"), + (create_second_pandas_edges_dataframe(), "source", "target"), + ] + ) assert medrecord.edge_count() == 4 @@ -942,10 +958,12 @@ def test_add_edges(self) -> None: second_edges = pl.from_pandas(create_second_pandas_edges_dataframe()) - medrecord.add_edges([ - (edges, "source", "target"), - (second_edges, "source", "target"), - ]) + medrecord.add_edges( + [ + (edges, "source", "target"), + (second_edges, "source", "target"), + ] + ) assert medrecord.edge_count() == 4 diff --git a/medmodels/medrecord/tests/test_overview.py b/medmodels/medrecord/tests/test_overview.py index 590e31f..ff822bd 100644 --- a/medmodels/medrecord/tests/test_overview.py +++ b/medmodels/medrecord/tests/test_overview.py @@ -11,25 +11,31 @@ def create_medrecord(): - patients = pd.DataFrame({ - "index": ["P1", "P2", "P3"], - "age": [20, 30, 70], - }) + patients = pd.DataFrame( + { + "index": ["P1", "P2", "P3"], + "age": [20, 30, 70], + } + ) diagnosis = pd.DataFrame({"index": ["D1", "D2"]}) - prescriptions = pd.DataFrame({ - "index": ["M1", "M2", "M3"], - "ATC": ["B01AF01", "B01AA03", np.nan], - }) + prescriptions = pd.DataFrame( + { + "index": ["M1", "M2", "M3"], + "ATC": ["B01AF01", "B01AA03", np.nan], + } + ) nodes = [patients, diagnosis, prescriptions] - edges = pd.DataFrame({ - "source": ["D1", "M1", "D1"], - "target": ["P1", "P2", "P3"], - "time": ["2000-01-01", "1999-10-15", "1999-12-15"], - }) + edges = pd.DataFrame( + { + "source": ["D1", "M1", "D1"], + "target": ["P1", "P2", "P3"], + "time": ["2000-01-01", "1999-10-15", "1999-12-15"], + } + ) edges.time = pd.to_datetime(edges.time) diff --git a/medmodels/medrecord/tests/test_schema.py b/medmodels/medrecord/tests/test_schema.py index 3937f19..e4a7206 100644 --- a/medmodels/medrecord/tests/test_schema.py +++ b/medmodels/medrecord/tests/test_schema.py @@ -16,15 +16,17 @@ def setUp(self) -> None: self.schema = create_medrecord().schema def test_groups(self) -> None: - assert sorted([ - "diagnosis", - "drug", - "patient_diagnosis", - "patient_drug", - "patient_procedure", - "patient", - "procedure", - ]) == sorted(self.schema.groups) + assert sorted( + [ + "diagnosis", + "drug", + "patient_diagnosis", + "patient_drug", + "patient_procedure", + "patient", + "procedure", + ] + ) == sorted(self.schema.groups) def test_group(self) -> None: assert isinstance(self.schema.group("patient"), mr.GroupSchema) # pyright: ignore[reportUnnecessaryIsInstance] diff --git a/medmodels/treatment_effect/continuous_estimators.py b/medmodels/treatment_effect/continuous_estimators.py index 0cbdaf5..a6cc5c3 100644 --- a/medmodels/treatment_effect/continuous_estimators.py +++ b/medmodels/treatment_effect/continuous_estimators.py @@ -59,34 +59,38 @@ def average_treatment_effect( Raises: ValueError: If the outcome variable is not numeric. """ - treated_outcomes = np.array([ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in treatment_outcome_true_set - ]) + treated_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute=time_attribute, + reference=reference, + ) + ][outcome_variable] + for node_index in treatment_outcome_true_set + ] + ) if not all(isinstance(i, (int, float)) for i in treated_outcomes): msg = "Outcome variable must be numeric" raise ValueError(msg) - control_outcomes = np.array([ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute="time", - reference=reference, - ) - ][outcome_variable] - for node_index in control_outcome_true_set - ]) + control_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute="time", + reference=reference, + ) + ][outcome_variable] + for node_index in control_outcome_true_set + ] + ) if not all(isinstance(i, (int, float)) for i in control_outcomes): msg = "Outcome variable must be numeric" raise ValueError(msg) @@ -151,34 +155,38 @@ def cohens_d( Raises: ValueError: If the outcome variable is not numeric. """ - treated_outcomes = np.array([ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute=time_attribute, - reference=reference, - ) - ][outcome_variable] - for node_index in treatment_outcome_true_set - ]) + treated_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute=time_attribute, + reference=reference, + ) + ][outcome_variable] + for node_index in treatment_outcome_true_set + ] + ) if not all(isinstance(i, (int, float)) for i in treated_outcomes): msg = "Outcome variable must be numeric" raise ValueError(msg) - control_outcomes = np.array([ - medrecord.edge[ - find_reference_edge( - medrecord, - node_index, - outcome_group, - time_attribute="time", - reference=reference, - ) - ][outcome_variable] - for node_index in control_outcome_true_set - ]) + control_outcomes = np.array( + [ + medrecord.edge[ + find_reference_edge( + medrecord, + node_index, + outcome_group, + time_attribute="time", + reference=reference, + ) + ][outcome_variable] + for node_index in control_outcome_true_set + ] + ) if not all(isinstance(i, (int, float)) for i in control_outcomes): msg = "Outcome variable must be numeric" raise ValueError(msg) diff --git a/medmodels/treatment_effect/matching/propensity.py b/medmodels/treatment_effect/matching/propensity.py index eaa3def..04418d7 100644 --- a/medmodels/treatment_effect/matching/propensity.py +++ b/medmodels/treatment_effect/matching/propensity.py @@ -97,10 +97,12 @@ def match_controls( # Train the classification model x_train = np.concatenate((treated_array, control_array)) - y_train = np.concatenate(( - np.ones(len(treated_array)), - np.zeros(len(control_array)), - )) + y_train = np.concatenate( + ( + np.ones(len(treated_array)), + np.zeros(len(control_array)), + ) + ) treated_prop, control_prop = calculate_propensity( x_train=x_train, diff --git a/medmodels/treatment_effect/matching/tests/test_metrics.py b/medmodels/treatment_effect/matching/tests/test_metrics.py index a9e195f..bc052ba 100644 --- a/medmodels/treatment_effect/matching/tests/test_metrics.py +++ b/medmodels/treatment_effect/matching/tests/test_metrics.py @@ -17,11 +17,13 @@ def test_exact_metric(self) -> None: assert metrics.exact_metric(np.array([2, -1]), np.array([2, 1])) == np.inf def test_mahalanobis_metric(self) -> None: - data = np.array([ - [64, 66, 68, 69, 73], - [580, 570, 590, 660, 600], - [29, 33, 37, 46, 55], - ]) + data = np.array( + [ + [64, 66, 68, 69, 73], + [580, 570, 590, 660, 600], + [29, 33, 37, 46, 55], + ] + ) inv_cov = np.linalg.inv(np.cov(data)) a1, a2 = np.array([68, 600, 40]), np.array([66, 640, 44]) result = metrics.mahalanobis_metric(a1, a2, inv_cov=inv_cov) diff --git a/medmodels/treatment_effect/matching/tests/test_propensity_score.py b/medmodels/treatment_effect/matching/tests/test_propensity_score.py index a1fc594..a823822 100644 --- a/medmodels/treatment_effect/matching/tests/test_propensity_score.py +++ b/medmodels/treatment_effect/matching/tests/test_propensity_score.py @@ -117,11 +117,13 @@ def test_run_propensity_score(self) -> None: assert result_logit.equals(expected_logit) # using 2 nearest neighbors - expected_logit = pl.DataFrame({ - "a": [1.0, 5.0], - "b": [3.0, 2.0], - "c": [5.0, 1.0], - }) + expected_logit = pl.DataFrame( + { + "a": [1.0, 5.0], + "b": [3.0, 2.0], + "c": [5.0, 1.0], + } + ) result_logit = ps.run_propensity_score( treated_set, control_set, diff --git a/medmodels/treatment_effect/tests/test_continuous_estimators.py b/medmodels/treatment_effect/tests/test_continuous_estimators.py index 16a84fb..952c2d4 100644 --- a/medmodels/treatment_effect/tests/test_continuous_estimators.py +++ b/medmodels/treatment_effect/tests/test_continuous_estimators.py @@ -20,21 +20,23 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: Returns: pd.DataFrame: A patients dataframe. """ - patients = pd.DataFrame({ - "index": ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"], - "age": [20, 30, 40, 30, 40, 50, 60, 70, 80], - "gender": [ - "male", - "female", - "male", - "female", - "male", - "female", - "male", - "female", - "male", - ], - }) + patients = pd.DataFrame( + { + "index": ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"], + "age": [20, 30, 40, 30, 40, 50, 60, 70, 80], + "gender": [ + "male", + "female", + "male", + "female", + "male", + "female", + "male", + "female", + "male", + ], + } + ) return patients.loc[patients["index"].isin(patient_list)] @@ -45,10 +47,12 @@ def create_diagnoses() -> pd.DataFrame: Returns: pd.DataFrame: A diagnoses dataframe. """ - return pd.DataFrame({ - "index": ["D1"], - "name": ["Stroke"], - }) + return pd.DataFrame( + { + "index": ["D1"], + "name": ["Stroke"], + } + ) def create_prescriptions() -> pd.DataFrame: @@ -57,10 +61,12 @@ def create_prescriptions() -> pd.DataFrame: Returns: pd.DataFrame: A prescriptions dataframe. """ - return pd.DataFrame({ - "index": ["M1", "M2"], - "name": ["Rivaroxaban", "Warfarin"], - }) + return pd.DataFrame( + { + "index": ["M1", "M2"], + "name": ["Rivaroxaban", "Warfarin"], + } + ) def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -69,35 +75,37 @@ def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: Returns: pd.DataFrame: An edges dataframe. """ - edges = pd.DataFrame({ - "source": [ - "M2", - "M1", - "M2", - "M1", - "M2", - "M1", - "M2", - ], - "target": [ - "P1", - "P2", - "P2", - "P3", - "P5", - "P6", - "P9", - ], - "time": [ - "1999-10-15", - "2000-01-01", - "1999-12-15", - "2000-01-01", - "2000-01-01", - "2000-01-01", - "2000-01-01", - ], - }) + edges = pd.DataFrame( + { + "source": [ + "M2", + "M1", + "M2", + "M1", + "M2", + "M1", + "M2", + ], + "target": [ + "P1", + "P2", + "P2", + "P3", + "P5", + "P6", + "P9", + ], + "time": [ + "1999-10-15", + "2000-01-01", + "1999-12-15", + "2000-01-01", + "2000-01-01", + "2000-01-01", + "2000-01-01", + ], + } + ) return edges.loc[edges["target"].isin(patient_list)] @@ -107,48 +115,50 @@ def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: Returns: pd.DataFrame: An edges dataframe. """ - edges = pd.DataFrame({ - "source": [ - "D1", - "D1", - "D1", - "D1", - "D1", - "D1", - ], - "target": [ - "P1", - "P2", - "P3", - "P3", - "P4", - "P7", - ], - "time": [ - "2000-01-01", - "2000-07-01", - "1999-12-15", - "2000-01-05", - "2000-01-01", - "2000-01-01", - ], - "intensity": [ - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - ], - "type": [ - "A", - "B", - "A", - "B", - "A", - "A", - ], - }) + edges = pd.DataFrame( + { + "source": [ + "D1", + "D1", + "D1", + "D1", + "D1", + "D1", + ], + "target": [ + "P1", + "P2", + "P3", + "P3", + "P4", + "P7", + ], + "time": [ + "2000-01-01", + "2000-07-01", + "1999-12-15", + "2000-01-05", + "2000-01-01", + "2000-01-01", + ], + "intensity": [ + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + ], + "type": [ + "A", + "B", + "A", + "B", + "A", + "A", + ], + } + ) return edges.loc[edges["target"].isin(patient_list)] diff --git a/medmodels/treatment_effect/tests/test_temporal_analysis.py b/medmodels/treatment_effect/tests/test_temporal_analysis.py index af867fc..35ac6b8 100644 --- a/medmodels/treatment_effect/tests/test_temporal_analysis.py +++ b/medmodels/treatment_effect/tests/test_temporal_analysis.py @@ -18,15 +18,17 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: Returns: pd.DataFrame: A patients dataframe. """ - patients = pd.DataFrame({ - "index": ["P1", "P2", "P3"], - "age": [20, 30, 40], - "gender": [ - "male", - "female", - "male", - ], - }) + patients = pd.DataFrame( + { + "index": ["P1", "P2", "P3"], + "age": [20, 30, 40], + "gender": [ + "male", + "female", + "male", + ], + } + ) return patients.loc[patients["index"].isin(patient_list)] @@ -37,10 +39,12 @@ def create_diagnoses() -> pd.DataFrame: Returns: pd.DataFrame: A diagnoses dataframe. """ - return pd.DataFrame({ - "index": ["D1"], - "name": ["Stroke"], - }) + return pd.DataFrame( + { + "index": ["D1"], + "name": ["Stroke"], + } + ) def create_prescriptions() -> pd.DataFrame: @@ -49,10 +53,12 @@ def create_prescriptions() -> pd.DataFrame: Returns: pd.DataFrame: A prescriptions dataframe. """ - return pd.DataFrame({ - "index": ["M1", "M2"], - "name": ["Rivaroxaban", "Warfarin"], - }) + return pd.DataFrame( + { + "index": ["M1", "M2"], + "name": ["Rivaroxaban", "Warfarin"], + } + ) def create_edges(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -61,29 +67,31 @@ def create_edges(patient_list: List[NodeIndex]) -> pd.DataFrame: Returns: pd.DataFrame: An edges dataframe. """ - edges = pd.DataFrame({ - "source": [ - "M1", - "M2", - "M1", - "M2", - "D1", - ], - "target": [ - "P1", - "P2", - "P3", - "P3", - "P3", - ], - "time": [ - "2000-01-01", - "2000-01-01", - "2000-01-01", - "1999-12-15", - "2000-07-01", - ], - }) + edges = pd.DataFrame( + { + "source": [ + "M1", + "M2", + "M1", + "M2", + "D1", + ], + "target": [ + "P1", + "P2", + "P3", + "P3", + "P3", + ], + "time": [ + "2000-01-01", + "2000-01-01", + "2000-01-01", + "1999-12-15", + "2000-07-01", + ], + } + ) return edges.loc[edges["target"].isin(patient_list)] diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index 21c5db7..e436e8c 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -19,21 +19,23 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: Returns: pd.DataFrame: A patients dataframe. """ - patients = pd.DataFrame({ - "index": ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"], - "age": [20, 30, 40, 30, 40, 50, 60, 70, 80], - "gender": [ - "male", - "female", - "male", - "female", - "male", - "female", - "male", - "female", - "male", - ], - }) + patients = pd.DataFrame( + { + "index": ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"], + "age": [20, 30, 40, 30, 40, 50, 60, 70, 80], + "gender": [ + "male", + "female", + "male", + "female", + "male", + "female", + "male", + "female", + "male", + ], + } + ) return patients.loc[patients["index"].isin(patient_list)] @@ -44,10 +46,12 @@ def create_diagnoses() -> pd.DataFrame: Returns: pd.DataFrame: A diagnoses dataframe. """ - return pd.DataFrame({ - "index": ["D1"], - "name": ["Stroke"], - }) + return pd.DataFrame( + { + "index": ["D1"], + "name": ["Stroke"], + } + ) def create_prescriptions() -> pd.DataFrame: @@ -56,10 +60,12 @@ def create_prescriptions() -> pd.DataFrame: Returns: pd.DataFrame: A prescriptions dataframe. """ - return pd.DataFrame({ - "index": ["M1", "M2"], - "name": ["Rivaroxaban", "Warfarin"], - }) + return pd.DataFrame( + { + "index": ["M1", "M2"], + "name": ["Rivaroxaban", "Warfarin"], + } + ) def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -68,35 +74,37 @@ def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: Returns: pd.DataFrame: An edges dataframe. """ - edges = pd.DataFrame({ - "source": [ - "M2", - "M1", - "M2", - "M1", - "M2", - "M1", - "M2", - ], - "target": [ - "P1", - "P2", - "P2", - "P3", - "P5", - "P6", - "P9", - ], - "time": [ - "1999-10-15", - "2000-01-01", - "1999-12-15", - "2000-01-01", - "2000-01-01", - "2000-01-01", - "2000-01-01", - ], - }) + edges = pd.DataFrame( + { + "source": [ + "M2", + "M1", + "M2", + "M1", + "M2", + "M1", + "M2", + ], + "target": [ + "P1", + "P2", + "P2", + "P3", + "P5", + "P6", + "P9", + ], + "time": [ + "1999-10-15", + "2000-01-01", + "1999-12-15", + "2000-01-01", + "2000-01-01", + "2000-01-01", + "2000-01-01", + ], + } + ) return edges.loc[edges["target"].isin(patient_list)] @@ -106,40 +114,42 @@ def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: Returns: pd.DataFrame: An edges dataframe. """ - edges = pd.DataFrame({ - "source": [ - "D1", - "D1", - "D1", - "D1", - "D1", - "D1", - ], - "target": [ - "P1", - "P2", - "P3", - "P3", - "P4", - "P7", - ], - "time": [ - "2000-01-01", - "2000-07-01", - "1999-12-15", - "2000-01-05", - "2000-01-01", - "2000-01-01", - ], - "intensity": [ - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - ], - }) + edges = pd.DataFrame( + { + "source": [ + "D1", + "D1", + "D1", + "D1", + "D1", + "D1", + ], + "target": [ + "P1", + "P2", + "P3", + "P3", + "P4", + "P7", + ], + "time": [ + "2000-01-01", + "2000-07-01", + "1999-12-15", + "2000-01-05", + "2000-01-01", + "2000-01-01", + ], + "intensity": [ + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + ], + } + ) return edges.loc[edges["target"].isin(patient_list)] @@ -406,12 +416,14 @@ def test_invalid_compute_subject_counts(self): control_outcome_true, control_outcome_false, ) = tee._find_groups(self.medrecord) - all_patients = set().union(*[ - treatment_outcome_true, - treatment_outcome_false, - control_outcome_true, - control_outcome_false, - ]) + all_patients = set().union( + *[ + treatment_outcome_true, + treatment_outcome_false, + control_outcome_true, + control_outcome_false, + ] + ) medrecord2 = create_medrecord( patient_list=list(all_patients - control_outcome_false) diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index eaa6bf2..45910e4 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -309,7 +309,26 @@ def query(node: NodeOperand): # Find patients that had the outcome before the treatment if self._outcome_before_treatment_days: - outcome_before_treatment_nodes.update({ + outcome_before_treatment_nodes.update( + { + node_index + for node_index in nodes_to_check + if find_node_in_time_window( + medrecord, + node_index, + outcome, + connected_group=self._treatments_group, + start_days=-self._outcome_before_treatment_days, + end_days=0, + reference="first", + ) + } + ) + nodes_to_check -= outcome_before_treatment_nodes + + # Find patients that had the outcome after the treatment + treatment_outcome_true.update( + { node_index for node_index in nodes_to_check if find_node_in_time_window( @@ -317,27 +336,12 @@ def query(node: NodeOperand): node_index, outcome, connected_group=self._treatments_group, - start_days=-self._outcome_before_treatment_days, - end_days=0, - reference="first", + start_days=self._grace_period_days, + end_days=self._follow_up_period_days, + reference=self._follow_up_period_reference, ) - }) - nodes_to_check -= outcome_before_treatment_nodes - - # Find patients that had the outcome after the treatment - treatment_outcome_true.update({ - node_index - for node_index in nodes_to_check - if find_node_in_time_window( - medrecord, - node_index, - outcome, - connected_group=self._treatments_group, - start_days=self._grace_period_days, - end_days=self._follow_up_period_days, - reference=self._follow_up_period_reference, - ) - }) + } + ) treated_group -= outcome_before_treatment_nodes if outcome_before_treatment_nodes: @@ -373,19 +377,21 @@ def _apply_washout_period( # TODO: washout in both directions? We need a List then for washout_group_id, washout_days in self._washout_period_days.items(): for washout_node in medrecord.nodes_in_group(washout_group_id): - washout_nodes.update({ - treated_node - for treated_node in treated_group - if find_node_in_time_window( - medrecord, - treated_node, - washout_node, - connected_group=self._treatments_group, - start_days=-washout_days, - end_days=0, - reference=self._washout_period_reference, - ) - }) + washout_nodes.update( + { + treated_node + for treated_node in treated_group + if find_node_in_time_window( + medrecord, + treated_node, + washout_node, + connected_group=self._treatments_group, + start_days=-washout_days, + end_days=0, + reference=self._washout_period_reference, + ) + } + ) treated_group -= washout_nodes diff --git a/pyproject.toml b/pyproject.toml index a625878..24872d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,6 +141,10 @@ ignore = [ "ISC002", ] +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["D", "DOC"] +"*/__init__.py" = ["D", "DOC"] + [tool.ruff.lint.pydocstyle] convention = "google"