Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve model_exploration step 2 output #155

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 121 additions & 44 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,28 @@
# https://github.com/ipums/hlink

import itertools
import logging
import math
import re
from time import perf_counter
from typing import Any
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_curve, auc
from pyspark.ml import Model, Transformer
import pyspark.sql
from pyspark.sql.functions import count, mean

import hlink.linking.core.threshold as threshold_core
import hlink.linking.core.classifier as classifier_core

from hlink.linking.link_step import LinkStep

logger = logging.getLogger(__name__)


class LinkStepTrainTestModels(LinkStep):
def __init__(self, task):
def __init__(self, task) -> None:
super().__init__(
task,
"train test models",
Expand All @@ -35,7 +42,7 @@ def __init__(self, task):
],
)

def _run(self):
def _run(self) -> None:
training_conf = str(self.task.training_conf)
table_prefix = self.task.table_prefix
config = self.task.link_run.config
Expand All @@ -61,7 +68,15 @@ def _run(self):
splits = self._get_splits(prepped_data, id_a, n_training_iterations, seed)

model_parameters = self._get_model_parameters(config)
for run in model_parameters:

logger.info(
f"There are {len(model_parameters)} sets of model parameters to explore; "
f"each of these has {n_training_iterations} train-test splits to test on"
)
for run_index, run in enumerate(model_parameters, 1):
run_start_info = f"Starting run {run_index} of {len(model_parameters)} with these parameters: {run}"
print(run_start_info)
logger.info(run_start_info)
params = run.copy()
model_type = params.pop("type")

Expand All @@ -80,20 +95,31 @@ def _run(self):
threshold_ratio = False

threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
results_dfs = {}
logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")

results_dfs: dict[int, pd.DataFrame] = {}
for i in range(len(threshold_matrix)):
results_dfs[i] = _create_results_df()

first = True
for training_data, test_data in splits:
for split_index, (training_data, test_data) in enumerate(splits, 1):
split_start_info = f"Training and testing the model on train-test split {split_index} of {n_training_iterations}"
print(split_start_info)
logger.debug(split_start_info)
training_data.cache()
test_data.cache()

classifier, post_transformer = classifier_core.choose_classifier(
model_type, params, dep_var
)

logger.debug("Training the model on the training data split")
start_train_time = perf_counter()
model = classifier.fit(training_data)
end_train_time = perf_counter()
logger.debug(
f"Successfully trained the model in {end_train_time - start_train_time:.2f}s"
)

predictions_tmp = _get_probability_and_select_pred_columns(
test_data, model, post_transformer, id_a, id_b, dep_var
Expand All @@ -113,7 +139,7 @@ def _run(self):
param_text = np.full(precision.shape, f"{model_type}_{params}")

pr_auc = auc(recall, precision)
print(f"Area under PR curve: {pr_auc}")
print(f"The area under the precision-recall curve is {pr_auc}")

if first:
prc = pd.DataFrame(
Expand All @@ -134,18 +160,24 @@ def _run(self):
first = False

i = 0
for at, tr in threshold_matrix:
for threshold_index, (alpha_threshold, threshold_ratio) in enumerate(
threshold_matrix, 1
):
logger.debug(
f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
f"{alpha_threshold=} and {threshold_ratio=}"
)
predictions = threshold_core.predict_using_thresholds(
predictions_tmp,
at,
tr,
alpha_threshold,
threshold_ratio,
config[training_conf],
config["id_column"],
)
predict_train = threshold_core.predict_using_thresholds(
predict_train_tmp,
at,
tr,
alpha_threshold,
threshold_ratio,
config[training_conf],
config["id_column"],
)
Expand All @@ -157,8 +189,8 @@ def _run(self):
model,
results_dfs[i],
otd_data,
at,
tr,
alpha_threshold,
threshold_ratio,
pr_auc,
)
i += 1
Expand All @@ -175,7 +207,19 @@ def _run(self):
self._save_otd_data(otd_data, self.task.spark)
self.task.spark.sql("set spark.sql.shuffle.partitions=200")

def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):
def _get_splits(
self,
prepped_data: pyspark.sql.DataFrame,
id_a: str,
n_training_iterations: int,
seed: int,
) -> list[list[pyspark.sql.DataFrame]]:
"""
Get a list of random splits of the prepped_data into two DataFrames.
There are n_training_iterations elements in the list. Each element is
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is super helpful.

itself a list of two DataFrames which are the splits of prepped_data.
The split DataFrames are roughly equal in size.
"""
if self.task.link_run.config[f"{self.task.training_conf}"].get(
"split_by_id_a", False
):
Expand All @@ -200,7 +244,7 @@ def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):

return splits

def _custom_param_grid_builder(self, conf):
def _custom_param_grid_builder(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
print("Building param grid for models")
given_parameters = conf[f"{self.task.training_conf}"]["model_parameters"]
new_params = []
Expand Down Expand Up @@ -231,19 +275,18 @@ def _custom_param_grid_builder(self, conf):

def _capture_results(
self,
predictions,
predict_train,
dep_var,
model,
results_df,
otd_data,
at,
tr,
pr_auc,
):
predictions: pyspark.sql.DataFrame,
predict_train: pyspark.sql.DataFrame,
dep_var: str,
model: Model,
results_df: pd.DataFrame,
otd_data: dict[str, Any] | None,
alpha_threshold: float,
threshold_ratio: float,
pr_auc: float,
) -> pd.DataFrame:
table_prefix = self.task.table_prefix

print("Evaluating model performance...")
# write to sql tables for testing
predictions.createOrReplaceTempView(f"{table_prefix}predictions")
predict_train.createOrReplaceTempView(f"{table_prefix}predict_train")
Expand Down Expand Up @@ -278,13 +321,13 @@ def _capture_results(
"test_mcc": [test_mcc],
"train_mcc": [train_mcc],
"model_id": [model],
"alpha_threshold": [at],
"threshold_ratio": [tr],
"alpha_threshold": [alpha_threshold],
"threshold_ratio": [threshold_ratio],
},
)
return pd.concat([results_df, new_results], ignore_index=True)

def _get_model_parameters(self, conf):
def _get_model_parameters(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
training_conf = str(self.task.training_conf)

model_parameters = conf[training_conf]["model_parameters"]
Expand All @@ -296,7 +339,9 @@ def _get_model_parameters(self, conf):
)
return model_parameters

def _save_training_results(self, desc_df, spark):
def _save_training_results(
self, desc_df: pd.DataFrame, spark: pyspark.sql.SparkSession
) -> None:
table_prefix = self.task.table_prefix

if desc_df.empty:
Expand All @@ -310,7 +355,9 @@ def _save_training_results(self, desc_df, spark):
f"Training results saved to Spark table '{table_prefix}training_results'."
)

def _prepare_otd_table(self, spark, df, id_a, id_b):
def _prepare_otd_table(
self, spark: pyspark.sql.SparkSession, df: pd.DataFrame, id_a: str, id_b: str
) -> pyspark.sql.DataFrame:
spark_df = spark.createDataFrame(df)
counted = (
spark_df.groupby(id_a, id_b)
Expand All @@ -323,7 +370,9 @@ def _prepare_otd_table(self, spark, df, id_a, id_b):
)
return counted

def _save_otd_data(self, otd_data, spark):
def _save_otd_data(
self, otd_data: dict[str, Any] | None, spark: pyspark.sql.SparkSession
) -> None:
table_prefix = self.task.table_prefix

if otd_data is None:
Expand Down Expand Up @@ -379,7 +428,7 @@ def _save_otd_data(self, otd_data, spark):
else:
print("There were no true negatives recorded.")

def _create_otd_data(self, id_a, id_b):
def _create_otd_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
"""Output Suspicous Data (OTD): used to check config to see if you should find sketchy training data that the models routinely mis-classify"""
training_conf = str(self.task.training_conf)
config = self.task.link_run.config
Expand All @@ -400,7 +449,12 @@ def _create_otd_data(self, id_a, id_b):
return None


def _calc_mcc(TP, TN, FP, FN):
def _calc_mcc(TP: int, TN: int, FP: int, FN: int) -> float:
"""
Given the counts of true positives (TP), true negatives (TN), false
positives (FP), and false negatives (FN) for a model run, compute the
Matthews Correlation Coefficient (MCC).
"""
if (math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))) != 0:
mcc = ((TP * TN) - (FP * FN)) / (
math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
Expand All @@ -410,7 +464,9 @@ def _calc_mcc(TP, TN, FP, FN):
return mcc


def _calc_threshold_matrix(alpha_threshold, threshold_ratio):
def _calc_threshold_matrix(
alpha_threshold: float | list[float], threshold_ratio: float | list[float]
) -> list[list[float]]:
if alpha_threshold and type(alpha_threshold) != list:
alpha_threshold = [alpha_threshold]

Expand All @@ -426,8 +482,13 @@ def _calc_threshold_matrix(alpha_threshold, threshold_ratio):


def _get_probability_and_select_pred_columns(
pred_df, model, post_transformer, id_a, id_b, dep_var
):
pred_df: pyspark.sql.DataFrame,
model: Model,
post_transformer: Transformer,
id_a: str,
id_b: str,
dep_var: str,
) -> pyspark.sql.DataFrame:
all_prediction_cols = set(
[
f"{id_a}",
Expand All @@ -446,7 +507,9 @@ def _get_probability_and_select_pred_columns(
return required_col_df


def _get_confusion_matrix(predictions, dep_var, otd_data):
def _get_confusion_matrix(
predictions: pyspark.sql.DataFrame, dep_var: str, otd_data: dict[str, Any] | None
) -> tuple[int, int, int, int]:
TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1))
TP_count = TP.count()

Expand Down Expand Up @@ -486,7 +549,16 @@ def _get_confusion_matrix(predictions, dep_var, otd_data):
return TP_count, FP_count, FN_count, TN_count


def _get_aggregate_metrics(TP_count, FP_count, FN_count, TN_count):
def _get_aggregate_metrics(
TP_count: int, FP_count: int, FN_count: int, TN_count: int
) -> tuple[float, float, float]:
"""
Given the counts of true positives, false positivies, false negatives, and
true negatives for a model run, compute several metrics to evaluate the
model's quality.

Return a tuple of (precision, recall, Matthews Correlation Coefficient).
"""
if (TP_count + FP_count) == 0:
precision = np.nan
else:
Expand All @@ -499,7 +571,7 @@ def _get_aggregate_metrics(TP_count, FP_count, FN_count, TN_count):
return precision, recall, mcc


def _create_results_df():
def _create_results_df() -> pd.DataFrame:
return pd.DataFrame(
columns=[
"precision_test",
Expand All @@ -516,7 +588,12 @@ def _create_results_df():
)


def _append_results(desc_df, results_df, model_type, params):
def _append_results(
desc_df: pd.DataFrame,
results_df: pd.DataFrame,
model_type: str,
params: dict[str, Any],
) -> pd.DataFrame:
# run.pop("type")
print(results_df)

Expand Down Expand Up @@ -548,7 +625,7 @@ def _append_results(desc_df, results_df, model_type, params):
return desc_df


def _print_desc_df(desc_df):
def _print_desc_df(desc_df: pd.DataFrame) -> None:
pd.set_option("display.max_colwidth", None)
print(
desc_df.drop(
Expand All @@ -564,7 +641,7 @@ def _print_desc_df(desc_df):
print("\n")


def _load_desc_df_params(desc_df):
def _load_desc_df_params(desc_df: pd.DataFrame) -> pd.DataFrame:
params = [
"maxDepth",
"numTrees",
Expand All @@ -591,7 +668,7 @@ def _load_desc_df_params(desc_df):
return desc_df


def _create_desc_df():
def _create_desc_df() -> pd.DataFrame:
return pd.DataFrame(
columns=[
"model",
Expand Down