Skip to content

Commit

Permalink
feat: Adding contingency table print and treatment groups variables r…
Browse files Browse the repository at this point in the history
…enaming (#211)
  • Loading branch information
MarIniOnz authored Sep 18, 2024
1 parent 82a76c2 commit f1e44ad
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 250 deletions.
32 changes: 16 additions & 16 deletions medmodels/treatment_effect/continuous_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

def average_treatment_effect(
medrecord: MedRecord,
treatment_true_set: Set[NodeIndex],
control_true_set: Set[NodeIndex],
treatment_outcome_true_set: Set[NodeIndex],
control_outcome_true_set: Set[NodeIndex],
outcome_group: Group,
outcome_variable: MedRecordAttribute,
reference: Literal["first", "last"] = "last",
Expand Down Expand Up @@ -39,10 +39,10 @@ def average_treatment_effect(
Args:
medrecord (MedRecord): An instance of the MedRecord class containing medical
data.
treatment_true_set (Set[NodeIndex]): A set of node indices representing the
treated group.
control_true_set (Set[NodeIndex]): A set of node indices representing the
control group.
treatment_outcome_true_set (Set[NodeIndex]): A set of node indices representing
the treated group that also have the outcome.
control_outcome_true_set (Set[NodeIndex]): A set of node indices representing
the control group that have the outcome.
outcome_group (Group): The group of nodes that contain the outcome variable.
outcome_variable (MedRecordAttribute): The attribute in the edge that contains
the outcome variable. It must be numeric and continuous.
Expand Down Expand Up @@ -70,7 +70,7 @@ def average_treatment_effect(
reference=reference,
)
][outcome_variable]
for node_index in treatment_true_set
for node_index in treatment_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in treated_outcomes):
Expand All @@ -87,7 +87,7 @@ def average_treatment_effect(
reference=reference,
)
][outcome_variable]
for node_index in control_true_set
for node_index in control_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in control_outcomes):
Expand All @@ -98,8 +98,8 @@ def average_treatment_effect(

def cohens_d(
medrecord: MedRecord,
treatment_true_set: Set[NodeIndex],
control_true_set: Set[NodeIndex],
treatment_outcome_true_set: Set[NodeIndex],
control_outcome_true_set: Set[NodeIndex],
outcome_group: Group,
outcome_variable: MedRecordAttribute,
reference: Literal["first", "last"] = "last",
Expand Down Expand Up @@ -130,10 +130,10 @@ def cohens_d(
Args:
medrecord (MedRecord): An instance of the MedRecord class containing medical
data.
treatment_true_set (Set[NodeIndex]): A set of node indices representing the
treated group.
control_true_set (Set[NodeIndex]): A set of node indices representing the
control group.
treatment_outcome_true_set (Set[NodeIndex]): A set of node indices representing
the treated group that also have the outcome.
control_outcome_true_set (Set[NodeIndex]): A set of node indices representing
the control group that have the outcome.
outcome_group (Group): The group of nodes that contain the outcome variable.
outcome_variable (MedRecordAttribute): The attribute in the edge that contains
the outcome variable. It must be numeric and continuous.
Expand Down Expand Up @@ -164,7 +164,7 @@ def cohens_d(
reference=reference,
)
][outcome_variable]
for node_index in treatment_true_set
for node_index in treatment_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in treated_outcomes):
Expand All @@ -181,7 +181,7 @@ def cohens_d(
reference=reference,
)
][outcome_variable]
for node_index in control_true_set
for node_index in control_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in control_outcomes):
Expand Down
Loading

0 comments on commit f1e44ad

Please sign in to comment.