Skip to content

Commit

Permalink
refactor: disable docstring linting for tests and __init__ (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
JabobKrauskopf committed Oct 17, 2024
1 parent 75c7550 commit 73370d6
Show file tree
Hide file tree
Showing 14 changed files with 481 additions and 395 deletions.
10 changes: 6 additions & 4 deletions medmodels/medrecord/_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,12 @@ def prettify_table(
"-" * (sum(lengths) + len(lengths)),
]

table.extend([
"".join(f"{row[x]: <{lengths[x]}} " for x in range(len(lengths)))
for row in rows
])
table.extend(
[
"".join(f"{row[x]: <{lengths[x]}} " for x in range(len(lengths)))
for row in rows
]
)

table.append("-" * (sum(lengths) + len(lengths)))

Expand Down
20 changes: 12 additions & 8 deletions medmodels/medrecord/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,12 @@ def _convert_node(
else None,
)

return AttributesSchema({
x: _convert_node(self._group_schema.nodes[x])
for x in self._group_schema.nodes
})
return AttributesSchema(
{
x: _convert_node(self._group_schema.nodes[x])
for x in self._group_schema.nodes
}
)

@property
def edges(self) -> AttributesSchema:
Expand All @@ -352,10 +354,12 @@ def _convert_edge(
else None,
)

return AttributesSchema({
x: _convert_edge(self._group_schema.edges[x])
for x in self._group_schema.edges
})
return AttributesSchema(
{
x: _convert_edge(self._group_schema.edges[x])
for x in self._group_schema.edges
}
)

@property
def strict(self) -> Optional[bool]:
Expand Down
96 changes: 57 additions & 39 deletions medmodels/medrecord/tests/test_medrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,41 @@ def create_edges() -> List[Tuple[NodeIndex, NodeIndex, Attributes]]:


def create_pandas_nodes_dataframe() -> pd.DataFrame:
return pd.DataFrame({
"index": ["0", "1"],
"attribute": [1, 2],
})
return pd.DataFrame(
{
"index": ["0", "1"],
"attribute": [1, 2],
}
)


def create_second_pandas_nodes_dataframe() -> pd.DataFrame:
return pd.DataFrame({
"index": ["2", "3"],
"attribute": [2, 3],
})
return pd.DataFrame(
{
"index": ["2", "3"],
"attribute": [2, 3],
}
)


def create_pandas_edges_dataframe() -> pd.DataFrame:
return pd.DataFrame({
"source": ["0", "1"],
"target": ["1", "0"],
"attribute": [1, 2],
})
return pd.DataFrame(
{
"source": ["0", "1"],
"target": ["1", "0"],
"attribute": [1, 2],
}
)


def create_second_pandas_edges_dataframe() -> pd.DataFrame:
return pd.DataFrame({
"source": ["0", "1"],
"target": ["1", "0"],
"attribute": [2, 3],
})
return pd.DataFrame(
{
"source": ["0", "1"],
"target": ["1", "0"],
"attribute": [2, 3],
}
)


def create_medrecord() -> MedRecord:
Expand Down Expand Up @@ -277,11 +285,13 @@ def test_schema(self) -> None:

medrecord.add_edges_to_group("group", edge_index)

edge_index = medrecord.add_edges((
"0",
"1",
{"attribute": 1, "attribute2": "1"},
))
edge_index = medrecord.add_edges(
(
"0",
"1",
{"attribute": 1, "attribute2": "1"},
)
)

with self.assertRaises(ValueError):
medrecord.add_edges_to_group("group", edge_index)
Expand Down Expand Up @@ -602,10 +612,12 @@ def test_add_nodes(self) -> None:

assert medrecord.node_count() == 0

medrecord.add_nodes([
(create_pandas_nodes_dataframe(), "index"),
(create_second_pandas_nodes_dataframe(), "index"),
])
medrecord.add_nodes(
[
(create_pandas_nodes_dataframe(), "index"),
(create_second_pandas_nodes_dataframe(), "index"),
]
)

assert medrecord.node_count() == 4

Expand Down Expand Up @@ -649,10 +661,12 @@ def test_add_nodes(self) -> None:

assert medrecord.node_count() == 0

medrecord.add_nodes([
(nodes, "index"),
(second_nodes, "index"),
])
medrecord.add_nodes(
[
(nodes, "index"),
(second_nodes, "index"),
]
)

assert medrecord.node_count() == 4

Expand Down Expand Up @@ -908,10 +922,12 @@ def test_add_edges(self) -> None:

assert medrecord.edge_count() == 0

medrecord.add_edges([
(create_pandas_edges_dataframe(), "source", "target"),
(create_second_pandas_edges_dataframe(), "source", "target"),
])
medrecord.add_edges(
[
(create_pandas_edges_dataframe(), "source", "target"),
(create_second_pandas_edges_dataframe(), "source", "target"),
]
)

assert medrecord.edge_count() == 4

Expand Down Expand Up @@ -942,10 +958,12 @@ def test_add_edges(self) -> None:

second_edges = pl.from_pandas(create_second_pandas_edges_dataframe())

medrecord.add_edges([
(edges, "source", "target"),
(second_edges, "source", "target"),
])
medrecord.add_edges(
[
(edges, "source", "target"),
(second_edges, "source", "target"),
]
)

assert medrecord.edge_count() == 4

Expand Down
32 changes: 19 additions & 13 deletions medmodels/medrecord/tests/test_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,31 @@


def create_medrecord():
patients = pd.DataFrame({
"index": ["P1", "P2", "P3"],
"age": [20, 30, 70],
})
patients = pd.DataFrame(
{
"index": ["P1", "P2", "P3"],
"age": [20, 30, 70],
}
)

diagnosis = pd.DataFrame({"index": ["D1", "D2"]})

prescriptions = pd.DataFrame({
"index": ["M1", "M2", "M3"],
"ATC": ["B01AF01", "B01AA03", np.nan],
})
prescriptions = pd.DataFrame(
{
"index": ["M1", "M2", "M3"],
"ATC": ["B01AF01", "B01AA03", np.nan],
}
)

nodes = [patients, diagnosis, prescriptions]

edges = pd.DataFrame({
"source": ["D1", "M1", "D1"],
"target": ["P1", "P2", "P3"],
"time": ["2000-01-01", "1999-10-15", "1999-12-15"],
})
edges = pd.DataFrame(
{
"source": ["D1", "M1", "D1"],
"target": ["P1", "P2", "P3"],
"time": ["2000-01-01", "1999-10-15", "1999-12-15"],
}
)

edges.time = pd.to_datetime(edges.time)

Expand Down
20 changes: 11 additions & 9 deletions medmodels/medrecord/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ def setUp(self) -> None:
self.schema = create_medrecord().schema

def test_groups(self) -> None:
assert sorted([
"diagnosis",
"drug",
"patient_diagnosis",
"patient_drug",
"patient_procedure",
"patient",
"procedure",
]) == sorted(self.schema.groups)
assert sorted(
[
"diagnosis",
"drug",
"patient_diagnosis",
"patient_drug",
"patient_procedure",
"patient",
"procedure",
]
) == sorted(self.schema.groups)

def test_group(self) -> None:
assert isinstance(self.schema.group("patient"), mr.GroupSchema) # pyright: ignore[reportUnnecessaryIsInstance]
Expand Down
104 changes: 56 additions & 48 deletions medmodels/treatment_effect/continuous_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,34 +59,38 @@ def average_treatment_effect(
Raises:
ValueError: If the outcome variable is not numeric.
"""
treated_outcomes = np.array([
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
])
treated_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in treated_outcomes):
msg = "Outcome variable must be numeric"
raise ValueError(msg)

control_outcomes = np.array([
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
])
control_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in control_outcomes):
msg = "Outcome variable must be numeric"
raise ValueError(msg)
Expand Down Expand Up @@ -151,34 +155,38 @@ def cohens_d(
Raises:
ValueError: If the outcome variable is not numeric.
"""
treated_outcomes = np.array([
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
])
treated_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in treated_outcomes):
msg = "Outcome variable must be numeric"
raise ValueError(msg)

control_outcomes = np.array([
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
])
control_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in control_outcomes):
msg = "Outcome variable must be numeric"
raise ValueError(msg)
Expand Down
Loading

0 comments on commit 73370d6

Please sign in to comment.