diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index a7c79ec..8e391b8 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -4,11 +4,16 @@
 #   https://github.com/ipums/hlink
 
 import itertools
+import logging
 import math
 import re
+from time import perf_counter
+from typing import Any
 import numpy as np
 import pandas as pd
 from sklearn.metrics import precision_recall_curve, auc
+from pyspark.ml import Model, Transformer
+import pyspark.sql
 from pyspark.sql.functions import count, mean
 
 import hlink.linking.core.threshold as threshold_core
@@ -16,9 +21,11 @@
 
 from hlink.linking.link_step import LinkStep
 
+logger = logging.getLogger(__name__)
+
 
 class LinkStepTrainTestModels(LinkStep):
-    def __init__(self, task):
+    def __init__(self, task) -> None:
         super().__init__(
             task,
             "train test models",
@@ -35,7 +42,7 @@ def __init__(self, task):
             ],
         )
 
-    def _run(self):
+    def _run(self) -> None:
         training_conf = str(self.task.training_conf)
         table_prefix = self.task.table_prefix
         config = self.task.link_run.config
@@ -61,7 +68,15 @@ def _run(self):
         splits = self._get_splits(prepped_data, id_a, n_training_iterations, seed)
 
         model_parameters = self._get_model_parameters(config)
-        for run in model_parameters:
+
+        logger.info(
+            f"There are {len(model_parameters)} sets of model parameters to explore; "
+            f"each of these has {n_training_iterations} train-test splits to test on"
+        )
+        for run_index, run in enumerate(model_parameters, 1):
+            run_start_info = f"Starting run {run_index} of {len(model_parameters)} with these parameters: {run}"
+            print(run_start_info)
+            logger.info(run_start_info)
             params = run.copy()
             model_type = params.pop("type")
 
@@ -80,12 +95,17 @@
                 threshold_ratio = False
 
             threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
-            results_dfs = {}
+            logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")
+
+            results_dfs: dict[int, pd.DataFrame] = {}
             for i in range(len(threshold_matrix)):
                 results_dfs[i] = _create_results_df()
 
             first = True
-            for training_data, test_data in splits:
+            for split_index, (training_data, test_data) in enumerate(splits, 1):
+                split_start_info = f"Training and testing the model on train-test split {split_index} of {n_training_iterations}"
+                print(split_start_info)
+                logger.debug(split_start_info)
                 training_data.cache()
                 test_data.cache()
 
@@ -93,7 +113,13 @@
                     model_type, params, dep_var
                 )
 
+                logger.debug("Training the model on the training data split")
+                start_train_time = perf_counter()
                 model = classifier.fit(training_data)
+                end_train_time = perf_counter()
+                logger.debug(
+                    f"Successfully trained the model in {end_train_time - start_train_time:.2f}s"
+                )
                 predictions_tmp = _get_probability_and_select_pred_columns(
                     test_data, model, post_transformer, id_a, id_b, dep_var
                 )
@@ -113,7 +139,7 @@
 
                 param_text = np.full(precision.shape, f"{model_type}_{params}")
                 pr_auc = auc(recall, precision)
-                print(f"Area under PR curve: {pr_auc}")
+                print(f"The area under the precision-recall curve is {pr_auc}")
 
                 if first:
                     prc = pd.DataFrame(
@@ -134,18 +160,24 @@
 
                 first = False
                 i = 0
-                for at, tr in threshold_matrix:
+                for threshold_index, (alpha_threshold, threshold_ratio) in enumerate(
+                    threshold_matrix, 1
+                ):
+                    logger.debug(
+                        f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
+                        f"{alpha_threshold=} and {threshold_ratio=}"
+                    )
                     predictions = threshold_core.predict_using_thresholds(
                         predictions_tmp,
-                        at,
-                        tr,
+                        alpha_threshold,
+                        threshold_ratio,
                         config[training_conf],
                         config["id_column"],
                     )
                     predict_train = threshold_core.predict_using_thresholds(
                         predict_train_tmp,
-                        at,
-                        tr,
+                        alpha_threshold,
+                        threshold_ratio,
                         config[training_conf],
                         config["id_column"],
                     )
@@ -157,8 +189,8 @@
                         model,
                         results_dfs[i],
                         otd_data,
-                        at,
-                        tr,
+                        alpha_threshold,
+                        threshold_ratio,
                         pr_auc,
                     )
                     i += 1
@@ -175,7 +207,19 @@
         self._save_otd_data(otd_data, self.task.spark)
         self.task.spark.sql("set spark.sql.shuffle.partitions=200")
 
-    def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):
+    def _get_splits(
+        self,
+        prepped_data: pyspark.sql.DataFrame,
+        id_a: str,
+        n_training_iterations: int,
+        seed: int,
+    ) -> list[list[pyspark.sql.DataFrame]]:
+        """
+        Get a list of random splits of the prepped_data into two DataFrames.
+        There are n_training_iterations elements in the list. Each element is
+        itself a list of two DataFrames which are the splits of prepped_data.
+        The split DataFrames are roughly equal in size.
+        """
         if self.task.link_run.config[f"{self.task.training_conf}"].get(
             "split_by_id_a", False
         ):
@@ -200,7 +244,7 @@ def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):
 
         return splits
 
-    def _custom_param_grid_builder(self, conf):
+    def _custom_param_grid_builder(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
         print("Building param grid for models")
         given_parameters = conf[f"{self.task.training_conf}"]["model_parameters"]
         new_params = []
@@ -231,19 +275,18 @@ def _custom_param_grid_builder(self, conf):
 
     def _capture_results(
         self,
-        predictions,
-        predict_train,
-        dep_var,
-        model,
-        results_df,
-        otd_data,
-        at,
-        tr,
-        pr_auc,
-    ):
+        predictions: pyspark.sql.DataFrame,
+        predict_train: pyspark.sql.DataFrame,
+        dep_var: str,
+        model: Model,
+        results_df: pd.DataFrame,
+        otd_data: dict[str, Any] | None,
+        alpha_threshold: float,
+        threshold_ratio: float,
+        pr_auc: float,
+    ) -> pd.DataFrame:
         table_prefix = self.task.table_prefix
 
-        print("Evaluating model performance...")
         # write to sql tables for testing
         predictions.createOrReplaceTempView(f"{table_prefix}predictions")
         predict_train.createOrReplaceTempView(f"{table_prefix}predict_train")
@@ -278,13 +321,13 @@ def _capture_results(
                 "test_mcc": [test_mcc],
                 "train_mcc": [train_mcc],
                 "model_id": [model],
-                "alpha_threshold": [at],
-                "threshold_ratio": [tr],
+                "alpha_threshold": [alpha_threshold],
+                "threshold_ratio": [threshold_ratio],
             },
         )
         return pd.concat([results_df, new_results], ignore_index=True)
 
-    def _get_model_parameters(self, conf):
+    def _get_model_parameters(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
         training_conf = str(self.task.training_conf)
 
         model_parameters = conf[training_conf]["model_parameters"]
@@ -296,7 +339,9 @@
             )
         return model_parameters
 
-    def _save_training_results(self, desc_df, spark):
+    def _save_training_results(
+        self, desc_df: pd.DataFrame, spark: pyspark.sql.SparkSession
+    ) -> None:
         table_prefix = self.task.table_prefix
 
         if desc_df.empty:
@@ -310,7 +355,9 @@
                 f"Training results saved to Spark table '{table_prefix}training_results'."
            )
 
-    def _prepare_otd_table(self, spark, df, id_a, id_b):
+    def _prepare_otd_table(
+        self, spark: pyspark.sql.SparkSession, df: pd.DataFrame, id_a: str, id_b: str
+    ) -> pyspark.sql.DataFrame:
         spark_df = spark.createDataFrame(df)
         counted = (
             spark_df.groupby(id_a, id_b)
@@ -323,7 +370,9 @@
         )
         return counted
 
-    def _save_otd_data(self, otd_data, spark):
+    def _save_otd_data(
+        self, otd_data: dict[str, Any] | None, spark: pyspark.sql.SparkSession
+    ) -> None:
         table_prefix = self.task.table_prefix
 
         if otd_data is None:
@@ -379,7 +428,7 @@
         else:
             print("There were no true negatives recorded.")
 
-    def _create_otd_data(self, id_a, id_b):
+    def _create_otd_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
        """Output Suspicous Data (OTD): used to check config to see if you should find sketchy training data that the models routinely mis-classify"""
         training_conf = str(self.task.training_conf)
         config = self.task.link_run.config
@@ -400,7 +449,12 @@
             return None
 
-def _calc_mcc(TP, TN, FP, FN):
+def _calc_mcc(TP: int, TN: int, FP: int, FN: int) -> float:
+    """
+    Given the counts of true positives (TP), true negatives (TN), false
+    positives (FP), and false negatives (FN) for a model run, compute the
+    Matthews Correlation Coefficient (MCC).
+    """
     if (math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))) != 0:
         mcc = ((TP * TN) - (FP * FN)) / (
             math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
         )
@@ -410,7 +464,9 @@
     return mcc
 
 
-def _calc_threshold_matrix(alpha_threshold, threshold_ratio):
+def _calc_threshold_matrix(
+    alpha_threshold: float | list[float], threshold_ratio: float | list[float]
+) -> list[list[float]]:
     if alpha_threshold and type(alpha_threshold) != list:
         alpha_threshold = [alpha_threshold]
 
@@ -426,8 +482,13 @@
 
 
 def _get_probability_and_select_pred_columns(
-    pred_df, model, post_transformer, id_a, id_b, dep_var
-):
+    pred_df: pyspark.sql.DataFrame,
+    model: Model,
+    post_transformer: Transformer,
+    id_a: str,
+    id_b: str,
+    dep_var: str,
+) -> pyspark.sql.DataFrame:
     all_prediction_cols = set(
         [
             f"{id_a}",
@@ -446,7 +507,9 @@
     return required_col_df
 
 
-def _get_confusion_matrix(predictions, dep_var, otd_data):
+def _get_confusion_matrix(
+    predictions: pyspark.sql.DataFrame, dep_var: str, otd_data: dict[str, Any] | None
+) -> tuple[int, int, int, int]:
     TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1))
     TP_count = TP.count()
 
@@ -486,7 +549,16 @@
     return TP_count, FP_count, FN_count, TN_count
 
 
-def _get_aggregate_metrics(TP_count, FP_count, FN_count, TN_count):
+def _get_aggregate_metrics(
+    TP_count: int, FP_count: int, FN_count: int, TN_count: int
+) -> tuple[float, float, float]:
+    """
+    Given the counts of true positives, false positivies, false negatives, and
+    true negatives for a model run, compute several metrics to evaluate the
+    model's quality.
+
+    Return a tuple of (precision, recall, Matthews Correlation Coefficient).
+    """
     if (TP_count + FP_count) == 0:
         precision = np.nan
     else:
@@ -499,7 +571,7 @@
     return precision, recall, mcc
 
 
-def _create_results_df():
+def _create_results_df() -> pd.DataFrame:
     return pd.DataFrame(
         columns=[
             "precision_test",
@@ -516,7 +588,12 @@
     )
 
 
-def _append_results(desc_df, results_df, model_type, params):
+def _append_results(
+    desc_df: pd.DataFrame,
+    results_df: pd.DataFrame,
+    model_type: str,
+    params: dict[str, Any],
+) -> pd.DataFrame:
     # run.pop("type")
     print(results_df)
 
@@ -548,7 +625,7 @@
     return desc_df
 
 
-def _print_desc_df(desc_df):
+def _print_desc_df(desc_df: pd.DataFrame) -> None:
     pd.set_option("display.max_colwidth", None)
     print(
         desc_df.drop(
@@ -564,7 +641,7 @@
     print("\n")
 
 
-def _load_desc_df_params(desc_df):
+def _load_desc_df_params(desc_df: pd.DataFrame) -> pd.DataFrame:
     params = [
         "maxDepth",
         "numTrees",
@@ -591,7 +668,7 @@
     return desc_df
 
 
-def _create_desc_df():
+def _create_desc_df() -> pd.DataFrame:
     return pd.DataFrame(
         columns=[
             "model",