[R-package] Serialization can lose boosting rounds and lead to mismatches in predict #6214

david-cortes opened this issue Nov 26, 2023

The R package keeps a best_iteration number that it uses when serializing model objects. If a model whose best iteration is lower than its actual number of boosting rounds is serialized and then de-serialized, every round after the best iteration is lost, and predictions from the restored model will not match those of the original.

I think the best solution here would be to create a new C-level function to dump the model that would not take any additional parameters, dumping instead everything in the booster, and to use that function to handle the automated serialization of models in R/Python.
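
In the meantime, a possible user-level workaround is to serialize the full model string explicitly instead of calling saveRDS() on the booster object directly. This is a minimal sketch, not the package's internals: the helper names are made up, and it assumes that, as in the C API, a non-positive num_iteration requests all boosting rounds:

serialize_full_booster <- function(booster, path) {
  # dump every boosting round, not just up to best_iteration
  full_model_str <- booster$save_model_to_string(num_iteration = -1L)
  saveRDS(full_model_str, path)
}

restore_full_booster <- function(path) {
  # rebuild a booster from the complete text dump
  lgb.load(model_str = readRDS(path))
}

Note that a booster restored this way no longer carries attributes like best_iteration, which is part of why a dedicated C-level dump function would be a cleaner fix.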

Example using code from the tests:

library(lightgbm)
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test

set.seed(708L)
dtrain <- lgb.Dataset(
  data = train$data
  , label = train$label
  , params = list(num_threads = 1)
)
dvalid <- lgb.Dataset(
  data = test$data
  , label = test$label
  , params = list(num_threads = 1)
)

early_stopping_rounds <- 5L
bst_auc <- lgb.train(
  params = list(
    objective = "binary"
    , metric = "auc"
    , max_depth = 3L
    , early_stopping_rounds = early_stopping_rounds
    , verbose = 1
    , num_threads = 1
  )
  , data = dtrain
  , nrounds = 100L
  , valids = list(
    "valid1" = dvalid
  )
)

pred_old <- predict(bst_auc, agaricus.train$data[1:5, ], num_iteration = 99)
fname <- file.path(tempdir(), "lgb_model.Rds")
saveRDS(bst_auc, fname)
bst_new <- readRDS(fname)
lgb.restore_handle(bst_new)
pred_new <- predict(bst_new, agaricus.train$data[1:5, ], num_iteration = 99)

print(pred_old - pred_new)
[1]  0.09444400 -0.06709325 -0.06576094  0.09444400 -0.06249008
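
To confirm that the truncation happens in the serialized payload itself, the tree counts of the original and the restored boosters can be compared. A minimal sketch that counts the "Tree=" markers in the text dump (each tree block in the dump starts with one), again assuming that a non-positive num_iteration requests all rounds:

n_trees <- function(model_str) {
  length(gregexpr("Tree=", model_str, fixed = TRUE)[[1L]])
}

# if serialization truncated at best_iteration, the restored booster
# will report fewer trees than the original one
n_trees(bst_auc$save_model_to_string(num_iteration = -1L))
n_trees(bst_new$save_model_to_string(num_iteration = -1L))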