Skip to content

Commit

Permalink
Possibility to create groupKfold
Browse files Browse the repository at this point in the history
  • Loading branch information
AnthonyTedde committed Feb 23, 2020
1 parent 2d1f0d4 commit bd81c5b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
38 changes: 36 additions & 2 deletions R/t_outlier_test.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ t_outlier_test.default <- function(...){
t_outlier_test.formula <- function(x,
data,
method,
group,
k,
predictors = dplyr::everything(),
std_err = 3,
remove = TRUE,
verbose = TRUE,
Expand All @@ -54,6 +57,9 @@ t_outlier_test.formula <- function(x,
t_outlier_test_internal(x = f,
data = data,
method = method,
group = group,
k = k,
predictors = predictors,
std_err = std_err,
remove = remove,
verbose = verbose,
Expand All @@ -68,6 +74,9 @@ t_outlier_test.formula <- function(x,
t_outlier_test.recipe <- function(x,
data,
method,
group,
k,
predictors = dplyr::everything(),
std_err = 3,
remove = TRUE,
verbose = TRUE,
Expand All @@ -76,6 +85,9 @@ t_outlier_test.recipe <- function(x,
t_outlier_test_internal(x = x,
data = data,
method = method,
group = group,
k = k,
predictors = predictors,
std_err = std_err,
remove = remove,
verbose = verbose,
Expand All @@ -89,11 +101,16 @@ t_outlier_test.recipe <- function(x,
t_outlier_test_internal <- function(x,
data,
method,
group = NULL,
k = NULL,
predictors = dplyr::everything(),
std_err = 3,
remove = TRUE,
verbose = TRUE,
...){

data_calib <- data %>% dplyr::select(predictors)

method %>% purrr::map(.f = function(m){

outliers <- rep(T, nrow(data))
Expand All @@ -107,14 +124,31 @@ t_outlier_test_internal <- function(x,
arguments <- c(
list(
x,
data = data[outliers, ]
data = data_calib[outliers, ]
), m, list(...)
)

if(!purrr::is_empty(group)){
tryCatch({

fold_maxsize <- data[[group]] %>% unique %>% length
k <- ifelse(is.numeric(k) && k < fold_maxsize, k, fold_maxsize)

index <- caret::groupKFold(data[[group]], k = k)
arguments <- c(
arguments,
trControl = caret::trainControl(method = "cv",
index = index)
)
}, error = function(e){
stop(e)
})
}

### call caret::train
model_calibration <- do.call(caret::train, arguments)

predicted_data <- predict(model_calibration, data[outliers, ])
predicted_data <- predict(model_calibration, data_calib[outliers, ])
Y <- outcome(model_calibration)

residuals <- Y - predicted_data
Expand Down
10 changes: 6 additions & 4 deletions man/t_outlier_test.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit bd81c5b

Please sign in to comment.