
Commit a330a66

[#154] Add type hints and a few doc comments to model_exploration.link_step_train_test_models
riley-harper committed Oct 8, 2024
1 parent 5171c25 commit a330a66
Showing 1 changed file with 80 additions and 31 deletions.
111 changes: 80 additions & 31 deletions hlink/linking/model_exploration/link_step_train_test_models.py
@@ -6,9 +6,12 @@
import itertools
import math
import re
+from typing import Any
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_curve, auc
+from pyspark.ml import Model, Transformer
+import pyspark.sql
from pyspark.sql.functions import count, mean

import hlink.linking.core.threshold as threshold_core
@@ -18,7 +21,7 @@


class LinkStepTrainTestModels(LinkStep):
-def __init__(self, task):
+def __init__(self, task) -> None:
super().__init__(
task,
"train test models",
@@ -35,7 +38,7 @@ def __init__(self, task):
],
)

-def _run(self):
+def _run(self) -> None:
training_conf = str(self.task.training_conf)
table_prefix = self.task.table_prefix
config = self.task.link_run.config
@@ -80,7 +83,7 @@ def _run(self):
threshold_ratio = False

threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
-results_dfs = {}
+results_dfs: dict[int, pd.DataFrame] = {}
for i in range(len(threshold_matrix)):
results_dfs[i] = _create_results_df()

@@ -175,7 +178,19 @@ def _run(self):
self._save_otd_data(otd_data, self.task.spark)
self.task.spark.sql("set spark.sql.shuffle.partitions=200")

-def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):
+def _get_splits(
+self,
+prepped_data: pyspark.sql.DataFrame,
+id_a: str,
+n_training_iterations: int,
+seed: int,
+) -> list[list[pyspark.sql.DataFrame]]:
+"""
+Get a list of random splits of the prepped_data into two DataFrames.
+There are n_training_iterations elements in the list. Each element is
+itself a list of two DataFrames which are the splits of prepped_data.
+The split DataFrames are roughly equal in size.
+"""
if self.task.link_run.config[f"{self.task.training_conf}"].get(
"split_by_id_a", False
):
@@ -200,7 +215,7 @@ def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):

return splits
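The method body is collapsed in this diff, so as a reading aid here is a minimal sketch of what the new docstring promises, assuming Spark's DataFrame.randomSplit; the committed code also has a split_by_id_a branch that this sketch ignores.

```python
# Minimal sketch, not the committed implementation: produce
# n_training_iterations two-way splits of roughly equal size.
from pyspark.sql import DataFrame


def get_splits_sketch(
    prepped_data: DataFrame, n_training_iterations: int, seed: int
) -> list[list[DataFrame]]:
    # DataFrame.randomSplit returns a list of DataFrames whose sizes
    # are proportional to the given weights.
    return [
        prepped_data.randomSplit([0.5, 0.5], seed=seed + i)
        for i in range(n_training_iterations)
    ]
```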

-def _custom_param_grid_builder(self, conf):
+def _custom_param_grid_builder(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
print("Building param grid for models")
given_parameters = conf[f"{self.task.training_conf}"]["model_parameters"]
new_params = []
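The rest of the grid builder is collapsed. As a purely hypothetical illustration of the input it reads (the section name "training" and the model type are assumptions here; maxDepth and numTrees do appear in _load_desc_df_params below):

```python
# Hypothetical config shape only; the real schema is defined by hlink's
# config files, not by this commit.
conf = {
    "training": {
        "model_parameters": [
            {"type": "random_forest", "maxDepth": [5, 10], "numTrees": [50, 100]},
        ]
    }
}
```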
@@ -231,16 +246,16 @@

def _capture_results(
self,
-predictions,
-predict_train,
-dep_var,
-model,
-results_df,
-otd_data,
-at,
-tr,
-pr_auc,
-):
+predictions: pyspark.sql.DataFrame,
+predict_train: pyspark.sql.DataFrame,
+dep_var: str,
+model: Model,
+results_df: pd.DataFrame,
+otd_data: dict[str, Any] | None,
+at: float,
+tr: float,
+pr_auc: float,
+) -> pd.DataFrame:
table_prefix = self.task.table_prefix

print("Evaluating model performance...")
@@ -284,7 +299,7 @@ def _capture_results(
)
return pd.concat([results_df, new_results], ignore_index=True)
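The return statement uses the standard pandas idiom for appending rows, since DataFrame.append was deprecated and later removed: build a small frame and concatenate it on. For example:

```python
import pandas as pd

results_df = pd.DataFrame(columns=["alpha_threshold", "pr_auc"])
new_results = pd.DataFrame([{"alpha_threshold": 0.8, "pr_auc": 0.91}])
# ignore_index=True renumbers the combined rows 0..n-1.
results_df = pd.concat([results_df, new_results], ignore_index=True)
```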

-def _get_model_parameters(self, conf):
+def _get_model_parameters(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
training_conf = str(self.task.training_conf)

model_parameters = conf[training_conf]["model_parameters"]
@@ -296,7 +311,9 @@ def _get_model_parameters(self, conf):
)
return model_parameters

-def _save_training_results(self, desc_df, spark):
+def _save_training_results(
+self, desc_df: pd.DataFrame, spark: pyspark.sql.SparkSession
+) -> None:
table_prefix = self.task.table_prefix

if desc_df.empty:
@@ -310,7 +327,9 @@ def _save_training_results(self, desc_df, spark):
f"Training results saved to Spark table '{table_prefix}training_results'."
)

-def _prepare_otd_table(self, spark, df, id_a, id_b):
+def _prepare_otd_table(
+self, spark: pyspark.sql.SparkSession, df: pd.DataFrame, id_a: str, id_b: str
+) -> pyspark.sql.DataFrame:
spark_df = spark.createDataFrame(df)
counted = (
spark_df.groupby(id_a, id_b)
@@ -323,7 +342,9 @@ def _prepare_otd_table(self, spark, df, id_a, id_b):
)
return counted
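A self-contained sketch of the same count/mean aggregation pattern; the column names here are assumptions, not taken from the collapsed body:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import count, mean

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("a1", "b1", 0.9), ("a1", "b1", 0.7)],
    ["id_a", "id_b", "probability"],  # illustrative column names
)
counted = df.groupby("id_a", "id_b").agg(
    count("*").alias("count"), mean("probability").alias("mean_probability")
)
```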

-def _save_otd_data(self, otd_data, spark):
+def _save_otd_data(
+self, otd_data: dict[str, Any] | None, spark: pyspark.sql.SparkSession
+) -> None:
table_prefix = self.task.table_prefix

if otd_data is None:
@@ -379,7 +400,7 @@ def _save_otd_data(self, otd_data, spark):
else:
print("There were no true negatives recorded.")

-def _create_otd_data(self, id_a, id_b):
+def _create_otd_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
"""Output Suspicious Data (OTD): used to check the config to see if you should find sketchy training data that the models routinely misclassify"""
training_conf = str(self.task.training_conf)
config = self.task.link_run.config
@@ -400,7 +421,12 @@ def _create_otd_data(self, id_a, id_b):
return None


-def _calc_mcc(TP, TN, FP, FN):
+def _calc_mcc(TP: int, TN: int, FP: int, FN: int) -> float:
+"""
+Given the counts of true positives (TP), true negatives (TN), false
+positives (FP), and false negatives (FN) for a model run, compute the
+Matthews Correlation Coefficient (MCC).
+"""
if (math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))) != 0:
mcc = ((TP * TN) - (FP * FN)) / (
math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
@@ -410,7 +436,9 @@ def _calc_mcc(TP, TN, FP, FN):
return mcc
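For reference, the quantity computed above is

$$\mathrm{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP+FP)(TP+FN)(TN+FP)(TN+FN)}}$$

which ranges from -1 (total disagreement) to +1 (perfect prediction), with 0 indicating chance-level performance.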


-def _calc_threshold_matrix(alpha_threshold, threshold_ratio):
+def _calc_threshold_matrix(
+alpha_threshold: float | list[float], threshold_ratio: float | list[float]
+) -> list[list[float]]:
if alpha_threshold and type(alpha_threshold) != list:
alpha_threshold = [alpha_threshold]

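The remainder of the function is collapsed here. Judging only from the name and the declared list[list[float]] return type, it appears to pair every alpha threshold with every threshold ratio, so a call might behave like the following (assumed behavior, not confirmed by this diff):

```python
# Assumed behavior based on the signature alone.
_calc_threshold_matrix(0.8, [1.2, 1.3])
# -> [[0.8, 1.2], [0.8, 1.3]]
```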
@@ -426,8 +454,13 @@


def _get_probability_and_select_pred_columns(
-pred_df, model, post_transformer, id_a, id_b, dep_var
-):
+pred_df: pyspark.sql.DataFrame,
+model: Model,
+post_transformer: Transformer,
+id_a: str,
+id_b: str,
+dep_var: str,
+) -> pyspark.sql.DataFrame:
all_prediction_cols = set(
[
f"{id_a}",
@@ -446,7 +479,9 @@ def _get_probability_and_select_pred_columns(
return required_col_df
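A hedged sketch of the flow the signature implies; in pyspark.ml both a fitted Model and a Transformer expose .transform, but the exact selected columns here are assumptions:

```python
# Sketch: score the candidate pairs, post-process, then keep only the
# columns needed downstream. Column choices are illustrative.
predictions = model.transform(pred_df)
transformed = post_transformer.transform(predictions)
required_col_df = transformed.select(id_a, id_b, dep_var, "probability", "prediction")
```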


-def _get_confusion_matrix(predictions, dep_var, otd_data):
+def _get_confusion_matrix(
+predictions: pyspark.sql.DataFrame, dep_var: str, otd_data: dict[str, Any] | None
+) -> tuple[int, int, int, int]:
TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1))
TP_count = TP.count()

@@ -486,7 +521,16 @@ def _get_confusion_matrix(predictions, dep_var, otd_data):
return TP_count, FP_count, FN_count, TN_count
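A tiny self-contained example of the filter-and-count pattern used for each confusion-matrix cell (the data and the match column name are made up):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
preds = spark.createDataFrame(
    [(1, 1.0), (1, 0.0), (0, 1.0), (0, 0.0)], ["match", "prediction"]
)
dep_var = "match"
TP_count = preds.filter((preds[dep_var] == 1) & (preds.prediction == 1)).count()  # 1
FN_count = preds.filter((preds[dep_var] == 1) & (preds.prediction == 0)).count()  # 1
```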


-def _get_aggregate_metrics(TP_count, FP_count, FN_count, TN_count):
+def _get_aggregate_metrics(
+TP_count: int, FP_count: int, FN_count: int, TN_count: int
+) -> tuple[float, float, float]:
+"""
+Given the counts of true positives, false positives, false negatives, and
+true negatives for a model run, compute several metrics to evaluate the
+model's quality.
+Return a tuple of (precision, recall, Matthews Correlation Coefficient).
+"""
if (TP_count + FP_count) == 0:
precision = np.nan
else:
@@ -499,7 +543,7 @@ def _get_aggregate_metrics(TP_count, FP_count, FN_count, TN_count):
return precision, recall, mcc
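The standard definitions being computed, with NaN returned when a denominator is zero (as the guard above shows for precision):

$$\text{precision} = \frac{TP}{TP + FP}, \qquad \text{recall} = \frac{TP}{TP + FN}$$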


-def _create_results_df():
+def _create_results_df() -> pd.DataFrame:
return pd.DataFrame(
columns=[
"precision_test",
@@ -516,7 +560,12 @@ def _create_results_df():
)


-def _append_results(desc_df, results_df, model_type, params):
+def _append_results(
+desc_df: pd.DataFrame,
+results_df: pd.DataFrame,
+model_type: str,
+params: dict[str, Any],
+) -> pd.DataFrame:
# run.pop("type")
print(results_df)

@@ -548,7 +597,7 @@ def _append_results(desc_df, results_df, model_type, params):
return desc_df


-def _print_desc_df(desc_df):
+def _print_desc_df(desc_df: pd.DataFrame) -> None:
pd.set_option("display.max_colwidth", None)
print(
desc_df.drop(
@@ -564,7 +613,7 @@ def _print_desc_df(desc_df):
print("\n")


-def _load_desc_df_params(desc_df):
+def _load_desc_df_params(desc_df: pd.DataFrame) -> pd.DataFrame:
params = [
"maxDepth",
"numTrees",
@@ -591,7 +640,7 @@ def _load_desc_df_params(desc_df):
return desc_df


-def _create_desc_df():
+def _create_desc_df() -> pd.DataFrame:
return pd.DataFrame(
columns=[
"model",
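A side note on the annotation style this commit adopts: built-in generics such as dict[str, Any] (PEP 585) need Python 3.9+, and the X | None union syntax (PEP 604) is only evaluated at runtime on 3.10+. On older interpreters the same hints still work with postponed evaluation:

```python
# Postponed evaluation (PEP 563) keeps annotations as strings, so
# "dict[str, Any] | None" is legal even where | is not defined for types.
from __future__ import annotations

from typing import Any


def example(otd_data: dict[str, Any] | None) -> None:
    ...
```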

