Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: disable docstring linting for tests and __init__ #235

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions medmodels/medrecord/_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@
)


def extract_attribute_summary(

Check failure on line 21 in medmodels/medrecord/_overview.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (D417)

medmodels/medrecord/_overview.py:21:5: D417 Missing argument description in the docstring for `extract_attribute_summary`: `attribute_dictionary`
attribute_dictionary: Union[

Check failure on line 22 in medmodels/medrecord/_overview.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (FA100)

medmodels/medrecord/_overview.py:22:27: FA100 Add `from __future__ import annotations` to simplify `typing.Union`
Dict[EdgeIndex, Attributes], Dict[NodeIndex, Attributes]
],
schema: Optional[AttributesSchema] = None,

Check failure on line 25 in medmodels/medrecord/_overview.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (FA100)

medmodels/medrecord/_overview.py:25:13: FA100 Add `from __future__ import annotations` to simplify `typing.Optional`
) -> Dict[
MedRecordAttribute,
Union[TemporalAttributeInfo, NumericAttributeInfo, StringAttributeInfo],

Check failure on line 28 in medmodels/medrecord/_overview.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (FA100)

medmodels/medrecord/_overview.py:28:5: FA100 Add `from __future__ import annotations` to simplify `typing.Union`
]:
"""Extracts a summary from a node or edge attribute dictionary.

Args:
attribute_dict (Union[Dict[EdgeIndex, Attributes], Dict[NodeIndex, Attributes]]):

Check failure on line 33 in medmodels/medrecord/_overview.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (W505)

medmodels/medrecord/_overview.py:33:89: W505 Doc line too long (89 > 88)
Edges or Nodes and their attributes and values.
schema (Optional[AttributesSchema], optional): Attribute Schema for the group
nodes or edges. Defaults to None.
Expand Down Expand Up @@ -210,7 +210,7 @@
row[2] = str(attribute) if first_line else ""

# displaying info based on the type
if "values" in info.keys():

Check failure on line 213 in medmodels/medrecord/_overview.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (SIM118)

medmodels/medrecord/_overview.py:213:20: SIM118 Use `key in dict` instead of `key in dict.keys()`
row[3] = info[key]
else:
if isinstance(info[key], float):
Expand All @@ -232,10 +232,12 @@
"-" * (sum(lengths) + len(lengths)),
]

table.extend([
"".join(f"{row[x]: <{lengths[x]}} " for x in range(len(lengths)))
for row in rows
])
table.extend(
[
"".join(f"{row[x]: <{lengths[x]}} " for x in range(len(lengths)))
for row in rows
]
)

table.append("-" * (sum(lengths) + len(lengths)))

Expand Down
20 changes: 12 additions & 8 deletions medmodels/medrecord/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,12 @@ def _convert_node(
else None,
)

return AttributesSchema({
x: _convert_node(self._group_schema.nodes[x])
for x in self._group_schema.nodes
})
return AttributesSchema(
{
x: _convert_node(self._group_schema.nodes[x])
for x in self._group_schema.nodes
}
)

@property
def edges(self) -> AttributesSchema:
Expand All @@ -352,10 +354,12 @@ def _convert_edge(
else None,
)

return AttributesSchema({
x: _convert_edge(self._group_schema.edges[x])
for x in self._group_schema.edges
})
return AttributesSchema(
{
x: _convert_edge(self._group_schema.edges[x])
for x in self._group_schema.edges
}
)

@property
def strict(self) -> Optional[bool]:
Expand Down
96 changes: 57 additions & 39 deletions medmodels/medrecord/tests/test_medrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,41 @@ def create_edges() -> List[Tuple[NodeIndex, NodeIndex, Attributes]]:


def create_pandas_nodes_dataframe() -> pd.DataFrame:
return pd.DataFrame({
"index": ["0", "1"],
"attribute": [1, 2],
})
return pd.DataFrame(
{
"index": ["0", "1"],
"attribute": [1, 2],
}
)


def create_second_pandas_nodes_dataframe() -> pd.DataFrame:
return pd.DataFrame({
"index": ["2", "3"],
"attribute": [2, 3],
})
return pd.DataFrame(
{
"index": ["2", "3"],
"attribute": [2, 3],
}
)


def create_pandas_edges_dataframe() -> pd.DataFrame:
return pd.DataFrame({
"source": ["0", "1"],
"target": ["1", "0"],
"attribute": [1, 2],
})
return pd.DataFrame(
{
"source": ["0", "1"],
"target": ["1", "0"],
"attribute": [1, 2],
}
)


def create_second_pandas_edges_dataframe() -> pd.DataFrame:
return pd.DataFrame({
"source": ["0", "1"],
"target": ["1", "0"],
"attribute": [2, 3],
})
return pd.DataFrame(
{
"source": ["0", "1"],
"target": ["1", "0"],
"attribute": [2, 3],
}
)


def create_medrecord() -> MedRecord:
Expand Down Expand Up @@ -277,11 +285,13 @@ def test_schema(self) -> None:

medrecord.add_edges_to_group("group", edge_index)

edge_index = medrecord.add_edges((
"0",
"1",
{"attribute": 1, "attribute2": "1"},
))
edge_index = medrecord.add_edges(
(
"0",
"1",
{"attribute": 1, "attribute2": "1"},
)
)

with self.assertRaises(ValueError):
medrecord.add_edges_to_group("group", edge_index)
Expand Down Expand Up @@ -602,10 +612,12 @@ def test_add_nodes(self) -> None:

assert medrecord.node_count() == 0

medrecord.add_nodes([
(create_pandas_nodes_dataframe(), "index"),
(create_second_pandas_nodes_dataframe(), "index"),
])
medrecord.add_nodes(
[
(create_pandas_nodes_dataframe(), "index"),
(create_second_pandas_nodes_dataframe(), "index"),
]
)

assert medrecord.node_count() == 4

Expand Down Expand Up @@ -649,10 +661,12 @@ def test_add_nodes(self) -> None:

assert medrecord.node_count() == 0

medrecord.add_nodes([
(nodes, "index"),
(second_nodes, "index"),
])
medrecord.add_nodes(
[
(nodes, "index"),
(second_nodes, "index"),
]
)

assert medrecord.node_count() == 4

Expand Down Expand Up @@ -908,10 +922,12 @@ def test_add_edges(self) -> None:

assert medrecord.edge_count() == 0

medrecord.add_edges([
(create_pandas_edges_dataframe(), "source", "target"),
(create_second_pandas_edges_dataframe(), "source", "target"),
])
medrecord.add_edges(
[
(create_pandas_edges_dataframe(), "source", "target"),
(create_second_pandas_edges_dataframe(), "source", "target"),
]
)

assert medrecord.edge_count() == 4

Expand Down Expand Up @@ -942,10 +958,12 @@ def test_add_edges(self) -> None:

second_edges = pl.from_pandas(create_second_pandas_edges_dataframe())

medrecord.add_edges([
(edges, "source", "target"),
(second_edges, "source", "target"),
])
medrecord.add_edges(
[
(edges, "source", "target"),
(second_edges, "source", "target"),
]
)

assert medrecord.edge_count() == 4

Expand Down
32 changes: 19 additions & 13 deletions medmodels/medrecord/tests/test_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,31 @@


def create_medrecord():
patients = pd.DataFrame({
"index": ["P1", "P2", "P3"],
"age": [20, 30, 70],
})
patients = pd.DataFrame(
{
"index": ["P1", "P2", "P3"],
"age": [20, 30, 70],
}
)

diagnosis = pd.DataFrame({"index": ["D1", "D2"]})

prescriptions = pd.DataFrame({
"index": ["M1", "M2", "M3"],
"ATC": ["B01AF01", "B01AA03", np.nan],
})
prescriptions = pd.DataFrame(
{
"index": ["M1", "M2", "M3"],
"ATC": ["B01AF01", "B01AA03", np.nan],
}
)

nodes = [patients, diagnosis, prescriptions]

edges = pd.DataFrame({
"source": ["D1", "M1", "D1"],
"target": ["P1", "P2", "P3"],
"time": ["2000-01-01", "1999-10-15", "1999-12-15"],
})
edges = pd.DataFrame(
{
"source": ["D1", "M1", "D1"],
"target": ["P1", "P2", "P3"],
"time": ["2000-01-01", "1999-10-15", "1999-12-15"],
}
)

edges.time = pd.to_datetime(edges.time)

Expand Down
20 changes: 11 additions & 9 deletions medmodels/medrecord/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ def setUp(self) -> None:
self.schema = create_medrecord().schema

def test_groups(self) -> None:
assert sorted([
"diagnosis",
"drug",
"patient_diagnosis",
"patient_drug",
"patient_procedure",
"patient",
"procedure",
]) == sorted(self.schema.groups)
assert sorted(
[
"diagnosis",
"drug",
"patient_diagnosis",
"patient_drug",
"patient_procedure",
"patient",
"procedure",
]
) == sorted(self.schema.groups)

def test_group(self) -> None:
assert isinstance(self.schema.group("patient"), mr.GroupSchema) # pyright: ignore[reportUnnecessaryIsInstance]
Expand Down
104 changes: 56 additions & 48 deletions medmodels/treatment_effect/continuous_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,34 +59,38 @@ def average_treatment_effect(
Raises:
ValueError: If the outcome variable is not numeric.
"""
treated_outcomes = np.array([
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
])
treated_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in treated_outcomes):
msg = "Outcome variable must be numeric"
raise ValueError(msg)

control_outcomes = np.array([
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
])
control_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in control_outcomes):
msg = "Outcome variable must be numeric"
raise ValueError(msg)
Expand Down Expand Up @@ -151,34 +155,38 @@ def cohens_d(
Raises:
ValueError: If the outcome variable is not numeric.
"""
treated_outcomes = np.array([
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
])
treated_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in treated_outcomes):
msg = "Outcome variable must be numeric"
raise ValueError(msg)

control_outcomes = np.array([
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
])
control_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in control_outcomes):
msg = "Outcome variable must be numeric"
raise ValueError(msg)
Expand Down
Loading
Loading