
Commit

Updated R scoring server to support standard MLflow models endpoints. (mlflow#1706)

* Updated R scoring server to support endpoints compatible with the MLflow models REST API.

* nit

* Fix.

* Addressed review comments.

* Addressed review comments.
tomasatdatabricks committed Aug 9, 2019
1 parent 5cb973a commit 31287c6
Showing 4 changed files with 134 additions and 7 deletions.
6 changes: 3 additions & 3 deletions mlflow/R/mlflow/R/cli.R
@@ -21,20 +21,20 @@ mlflow_cli <- function(...,
env <- if (is.null(client)) list() else client$get_cli_env()
args <- list(...)
verbose <- mlflow_is_verbose()

python <- dirname(python_bin())
mlflow_bin <- python_mlflow_bin()
env <- modifyList(list(
PATH = paste(python, Sys.getenv("PATH"), sep = ":"),
MLFLOW_CONDA_HOME = python_conda_home(),
-MLFLOW_TRACKING_URI = mlflow_get_tracking_uri()
+MLFLOW_TRACKING_URI = mlflow_get_tracking_uri(),
+MLFLOW_BIN = mlflow_bin,
+MLFLOW_PYTHON_BIN = python_bin()
), env)
if (is.null(stderr_callback)) {
stderr_callback <- function(x, p) {
cat(x, file = stderr())
}
}

with_envvar(env, {
if (background) {
result <- process$new(mlflow_bin, args = unlist(args), echo_cmd = verbose, supervise = TRUE)
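As context for the new MLFLOW_BIN and MLFLOW_PYTHON_BIN variables, here is a minimal standalone sketch of the pattern used above, assuming the withr and processx packages (which the R client already uses); the paths are hypothetical placeholders:

library(withr)
library(processx)

# Hypothetical locations standing in for python_bin() and python_mlflow_bin()
env <- list(
  MLFLOW_BIN = "/opt/conda/bin/mlflow",        # path to the mlflow CLI entry point
  MLFLOW_PYTHON_BIN = "/opt/conda/bin/python"  # interpreter the CLI should use
)

with_envvar(env, {
  # Mirrors the background = TRUE branch: a supervised child process running the CLI
  p <- process$new("/opt/conda/bin/mlflow", args = c("--help"), supervise = TRUE)
  p$wait()
})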
49 changes: 47 additions & 2 deletions mlflow/R/mlflow/R/model-serve.R
@@ -2,7 +2,12 @@

#' Serve an RFunc MLflow Model
#'
-#' Serves an RFunc MLflow model as a local web API.
+#' Serves an RFunc MLflow model as a local REST API server. This interface provides
+#' functionality similar to the ``mlflow models serve`` CLI command, but it can only be used to
+#' deploy models that include the RFunc flavor. The deployed server supports the standard MLflow
+#' models interface with /ping and /invocation endpoints. In addition, R function models also
+#' support the deprecated /predict endpoint for generating predictions. The /predict endpoint
+#' will be removed in a future version of MLflow.
#'
#' @template roxlate-model-uri
#' @param host Address to use to serve model, as a string.
@@ -94,7 +99,6 @@ serve_invalid_request <- function(message = NULL) {

serve_prediction <- function(json_raw, model, ...) {
mlflow_verbose_message("Serving prediction: ", json_raw)
-
df <- data.frame()
if (length(json_raw) > 0) {
df <- jsonlite::fromJSON(
@@ -157,6 +161,47 @@ serve_handlers <- function(host, port, ...) {
))
)
},
"^/ping" = function(req, model) {
if (!is.null(model) && !is.na(model)) {
res <- list(status = 200L,
headers = list(
"Content-Type" = paste0(serve_content_type("json"), "; charset=UTF-8")
),
body = ""
)
res
} else {
list(status = 404L,
headers = list(
"Content-Type" = paste0(serve_content_type("json"), "; charset=UTF-8")
)
)
}
},
"^/invocation" = function(req, model) {
data_raw <- rawToChar(req$rook.input$read())
headers <- strsplit(req$HEADERS, "\n")
content_type <- headers$`content-type` %||% "application/json"
df <- switch(content_type,
"application/json" = parse_json(data_raw, "split"),
"application/json; format=pandas-split" = parse_json(data_raw, "split"),
"application/json; format=pandas-records" = parse_json(data_raw, "records"),
"application/json-numpy-split" = parse_json(data_raw, "split"),
"text/csv" = utils::read.csv(text = data_raw, stringsAsFactors = FALSE),
stop("Unsupported input format.")
)
results <- mlflow_predict(model, df, ...)
list(
status = 200L,
headers = list(
"Content-Type" = paste0(serve_content_type("json"), "; charset=UTF-8")
),
body = charToRaw(enc2utf8(
jsonlite::toJSON(results, auto_unbox = TRUE, digits = NA, simplifyVector = TRUE)
))
)
},
"^/[^/]*$" = function(req, model) {
serve_static_file_response("swagger", file.path("dist", req$PATH_INFO))
},
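To see the new endpoints from a client's perspective, here is a hedged sketch assuming a serving process is already listening locally (the port, feature names, and values below are placeholders, not part of the commit):

library(httr)
library(jsonlite)

base <- "http://127.0.0.1:8090"

# /ping returns 200 with an empty body once a model is loaded
status_code(GET(paste0(base, "/ping")))

# /invocation accepts the pandas-split JSON layout, among others
payload <- toJSON(list(
  columns = c("Sepal.Length", "Petal.Width"),
  data = matrix(c(5.1, 0.2, 4.9, 0.2), nrow = 2, byrow = TRUE)
))
resp <- POST(paste0(base, "/invocation"),
             content_type("application/json; format=pandas-split"),
             body = payload)
content(resp)  # parsed predictions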
4 changes: 2 additions & 2 deletions mlflow/R/mlflow/R/model.R
@@ -185,7 +185,7 @@ supported_model_flavors <- function() {
parse_json <- function(input_path, json_format="split") {
switch(json_format,
split = {
-json <- jsonlite::read_json(input_path, simplifyVector = TRUE)
+json <- jsonlite::fromJSON(input_path, simplifyVector = TRUE)
elms <- names(json)
if (length(setdiff(elms, c("columns", "index", "data"))) != 0
|| length(setdiff(c("columns", "data"), elms)) != 0) {
@@ -195,7 +195,7 @@ parse_json <- function(input_path, json_format="split") {
names(df) <- json$columns
df
},
-records = jsonlite::read_json(input_path, simplifyVector = TRUE),
+records = jsonlite::fromJSON(input_path, simplifyVector = TRUE),
stop(paste("Unsupported JSON format", json_format,
". Supported formats are 'split' or 'records'"))
)
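For reference, a quick standalone illustration of the two layouts parse_json() accepts, using jsonlite::fromJSON directly (not part of the commit):

library(jsonlite)

# "split" layout: columns, index, and data travel separately
split_json <- '{"columns": ["a", "b"], "index": [1, 2], "data": [[1, 3], [2, 4]]}'
json <- fromJSON(split_json, simplifyVector = TRUE)
df <- data.frame(json$data)  # data arrives as a 2x2 matrix
names(df) <- json$columns
df
#>   a b
#> 1 1 3
#> 2 2 4

# "records" layout: one object per row, simplified straight to a data frame
fromJSON('[{"a": 1, "b": 3}, {"a": 2, "b": 4}]', simplifyVector = TRUE)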
82 changes: 82 additions & 0 deletions mlflow/R/mlflow/tests/testthat/test-serve.R
@@ -51,3 +51,85 @@ test_that("mlflow can serve a model function", {
tolerance = 1e-5
)
})


wait_for_server_to_start <- function(server_process) {
status_code <- 500
for (i in 1:5) {
tryCatch({
status_code <- httr::status_code(httr::GET("http://127.0.0.1:54321/ping"))
}, error = function(...) {
Sys.sleep(5)
})
if (status_code == 200) break
}
if (status_code != 200) {
write("FAILED to start the server!", stderr())
error_text <- server_process$read_error()
stop("Failed to start the server!", error_text)
}
}

test_that("mlflow models server api works with R model function", {
model <- lm(Sepal.Width ~ Sepal.Length + Petal.Width, iris)
fn <- crate(~ stats::predict(model, .x), model = model)
mlflow_save_model(fn, path = "model")
expect_true(dir.exists("model"))
server_process <- mlflow:::mlflow_cli("models", "serve", "-m", "model", "-p", "54321",
background = TRUE)
tryCatch({
wait_for_server_to_start(server_process)
newdata <- iris[1:2, c("Sepal.Length", "Petal.Width")]
check_prediction <- function(http_prediction) {
if (is.character(http_prediction)) {
stop(http_prediction)
}
expect_equal(
unlist(http_prediction),
as.vector(predict(model, newdata)),
tolerance = 1e-5
)
}
# json records
check_prediction(
httr::content(
httr::POST(
"http://127.0.0.1:54321/invocation/",
httr::content_type("application/json; format=pandas-records"),
body = jsonlite::toJSON(as.list(newdata))
)
)
)
newdata_split <- list(columns = names(newdata), index = row.names(newdata),
data = as.matrix(newdata))
# json split
for (content_type in c("application/json",
"application/json; format=pandas-split",
"application/json-numpy-split")) {
check_prediction(
httr::content(
httr::POST(
"http://127.0.0.1:54321/invocation/",
httr::content_type(content_type),
body = jsonlite::toJSON(newdata_split)
)
)
)
}
# csv
csv_header <- paste(names(newdata), collapse = ",")
csv_row_1 <- paste(newdata[1, ], collapse = ",")
csv_row_2 <- paste(newdata[2, ], collapse = ",")
newdata_csv <- paste(csv_header, csv_row_1, csv_row_2, "", sep = "\n")
check_prediction(
httr::content(
httr::POST(
"http://127.0.0.1:54321/invocation/",
httr::content_type("text/csv"),
body = newdata_csv
)
)
)
}, finally = {
server_process$kill()
})
})
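For reference, the CSV body assembled with paste() above renders as the following plain text (values are the first two iris rows):

cat(newdata_csv)
#> Sepal.Length,Petal.Width
#> 5.1,0.2
#> 4.9,0.2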
