Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
added a verbose option to mx.mlp in R package. close #1608
Browse files Browse the repository at this point in the history
  • Loading branch information
Qiang Kou committed Mar 20, 2016
1 parent fc922f0 commit 569c486
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 116 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,4 @@ List of Contributors
* [Dan Becker](https://github.com/dansbecker)
* [Yun Yan](https://github.com/Puriney)
* [Tao Wei](https://github.com/taoari)
* [Max Kuhn](https://github.com/topepo)
10 changes: 5 additions & 5 deletions R-package/R/callback.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ mx.metric.logger <- setRefClass("mx.metric.logger", fields = list(train = "numer
#' Log training metric each period
#' @export
mx.callback.log.train.metric <- function(period, logger=NULL) {
function(iteration, nbatch, env) {
function(iteration, nbatch, env, verbose=TRUE) {
if (nbatch %% period == 0 && !is.null(env$metric)) {
result <- env$metric$get(env$train.metric)
if (nbatch != 0)
if (nbatch != 0 & verbose)
cat(paste0("Batch [", nbatch, "] Train-", result$name, "=", result$value, "\n"))
if (!is.null(logger)) {
if (class(logger) != "mx.metric.logger") {
Expand All @@ -16,7 +16,7 @@ mx.callback.log.train.metric <- function(period, logger=NULL) {
logger$train <- c(logger$train, result$value)
if (!is.null(env$eval.metric)) {
result <- env$metric$get(env$eval.metric)
if (nbatch != 0)
if (nbatch != 0 & verbose)
cat(paste0("Batch [", nbatch, "] Validation-", result$name, "=", result$value, "\n"))
logger$eval <- c(logger$eval, result$value)
}
Expand All @@ -33,10 +33,10 @@ mx.callback.log.train.metric <- function(period, logger=NULL) {
#'
#' @export
mx.callback.save.checkpoint <- function(prefix, period=1) {
function(iteration, nbatch, env) {
function(iteration, nbatch, env, verbose=TRUE) {
if (iteration %% period == 0) {
mx.model.save(env$model, prefix, iteration)
cat(sprintf("Model checkpoint saved to %s-%04d.params\n", prefix, iteration))
if(verbose) cat(sprintf("Model checkpoint saved to %s-%04d.params\n", prefix, iteration))
}
return(TRUE)
}
Expand Down
5 changes: 2 additions & 3 deletions R-package/R/lr_scheduler.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

#' @export
FactorScheduler <- function(step, factor_val) {
FactorScheduler <- function(step, factor_val, verbose=TRUE) {
function(optimizerEnv){
num_update <- optimizerEnv$num_update
count <- optimizerEnv$count
Expand All @@ -9,10 +9,9 @@ FactorScheduler <- function(step, factor_val) {
if (num_update > count + step){
count = count + step
lr = lr * factor_val
cat(paste0("Update[", num_update, "]: learning rate is changed to ", lr, "\n"))
if(verbose) cat(paste0("Update[", num_update, "]: learning rate is changed to ", lr, "\n"))
optimizerEnv$lr <- lr
optimizerEnv$count <- count
}
# return(optimizerEnv)
}
}
31 changes: 18 additions & 13 deletions R-package/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ mx.model.extract.model <- function(symbol, train.execs) {
}

# decide what type of kvstore to use
mx.model.create.kvstore <- function(kvstore, arg.params, ndevice) {
mx.model.create.kvstore <- function(kvstore, arg.params, ndevice, verbose=TRUE) {
if (is.MXKVStore(kvstore)) return (kvstore)
if (!is.character(kvstore)) {
stop("kvstore msut be either MXKVStore or a string")
Expand All @@ -85,7 +85,7 @@ mx.model.create.kvstore <- function(kvstore, arg.params, ndevice) {
} else {
kvstore <- 'local_allreduce_cpu'
}
cat(paste0("Auto-select kvstore type = ", kvstore, "\n"))
if(verbose) cat(paste0("Auto-select kvstore type = ", kvstore, "\n"))
}
return(mx.kv.create(kvstore))
}
Expand All @@ -98,9 +98,10 @@ mx.model.train <- function(symbol, ctx, input.shape,
metric,
epoch.end.callback,
batch.end.callback,
kvstore) {
kvstore,
verbose=TRUE) {
ndevice <- length(ctx)
cat(paste0("Start training with ", ndevice, " devices\n"))
if(verbose) cat(paste0("Start training with ", ndevice, " devices\n"))
# create the executors
sliceinfo <- mx.model.slice.shape(input.shape, ndevice)
train.execs <- lapply(1:ndevice, function(i) {
Expand Down Expand Up @@ -203,7 +204,7 @@ mx.model.train <- function(symbol, ctx, input.shape,
train.data$reset()
if (!is.null(metric)) {
result <- metric$get(train.metric)
cat(paste0("[", iteration, "] Train-", result$name, "=", result$value, "\n"))
if(verbose) cat(paste0("[", iteration, "] Train-", result$name, "=", result$value, "\n"))
}
if (!is.null(eval.data)) {
if (!is.null(metric)) {
Expand Down Expand Up @@ -237,7 +238,7 @@ mx.model.train <- function(symbol, ctx, input.shape,
eval.data$reset()
if (!is.null(metric)) {
result <- metric$get(eval.metric)
cat(paste0("[", iteration, "] Validation-", result$name, "=", result$value, "\n"))
if(verbose) cat(paste0("[", iteration, "] Validation-", result$name, "=", result$value, "\n"))
}
} else {
eval.metric <- NULL
Expand All @@ -247,7 +248,7 @@ mx.model.train <- function(symbol, ctx, input.shape,

epoch_continue <- TRUE
if (!is.null(epoch.end.callback)) {
epoch_continue <- epoch.end.callback(iteration, 0, environment())
epoch_continue <- epoch.end.callback(iteration, 0, environment(), verbose = verbose)
}

if (!epoch_continue) {
Expand Down Expand Up @@ -295,10 +296,10 @@ mx.model.select.layout.train <- function(X, y) {
stop("Cannot auto select array.layout, please specify this parameter")
}
if (rowmajor == 1) {
cat("Auto detect layout of input matrix, use rowmajor..\n")
warning("Auto detect layout of input matrix, use rowmajor..\n")
return("rowmajor")
} else{
cat("Auto detect layout input matrix, use colmajor..\n")
warning("Auto detect layout input matrix, use colmajor..\n")
return("colmajor")
}
}
Expand Down Expand Up @@ -333,10 +334,10 @@ mx.model.select.layout.predict <- function(X, model) {
stop("Cannot auto select array.layout, please specify this parameter")
}
if (rowmajor == 1) {
cat("Auto detect layout of input matrix, use rowmajor..\n")
warning("Auto detect layout of input matrix, use rowmajor..\n")
return("rowmajor")
} else{
cat("Auto detect layout input matrix, use colmajor..\n")
warning("Auto detect layout input matrix, use colmajor..\n")
return("colmajor")
}
}
Expand Down Expand Up @@ -375,6 +376,8 @@ mx.model.select.layout.predict <- function(X, model) {
#' and will report error when X is a square matrix to ask user to explicitly specify layout.
#' @param kvstore string (default="local")
#' The parameter synchronization scheme in multiple devices.
#' @param verbose logical (default=TRUE)
#' Specifies whether to print information on the iterations during training.
#' @return model A trained mxnet model.
#'
#' @export
Expand All @@ -386,6 +389,7 @@ function(symbol, X, y=NULL, ctx=NULL,
epoch.end.callback=NULL, batch.end.callback=NULL,
array.batch.size=128, array.layout="auto",
kvstore="local",
verbose=TRUE,
...) {
if (is.array(X) || is.matrix(X)) {
if (array.layout == "auto") {
Expand Down Expand Up @@ -429,15 +433,16 @@ function(symbol, X, y=NULL, ctx=NULL,
}
eval.data <- mx.model.init.iter(eval.data$data, eval.data$label, batch.size=array.batch.size, is.train = TRUE)
}
kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx))
kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx), verbose=verbose)
model <- mx.model.train(symbol, ctx, input.shape,
params$arg.params, params$aux.params,
1, num.round, optimizer=optimizer,
train.data=X, eval.data=eval.data,
metric=eval.metric,
epoch.end.callback=epoch.end.callback,
batch.end.callback=batch.end.callback,
kvstore=kvstore)
kvstore=kvstore,
verbose=verbose)
return (model)
}

Expand Down
4 changes: 2 additions & 2 deletions R-package/R/mxnet_generated.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ mx.io.CSVIter <- function(...) {
#' Batch Param: Batch size.
#' @param round.batch boolean, optional, default=True
#' Batch Param: Use round robin to handle overflow batch.
#' @param prefetch.buffer long (non-negative), optional, default=4
#' @param prefetch.buffer , optional, default=4
#' Backend Param: Number of prefetched parameters
#' @param rand.crop boolean, optional, default=False
#' Augmentation Param: Whether to random crop on the image
Expand Down Expand Up @@ -354,7 +354,7 @@ mx.io.ImageRecordIter <- function(...) {
#' partition the data into multiple parts
#' @param part.index int, optional, default='0'
#' the index of the part will read
#' @param prefetch.buffer long (non-negative), optional, default=4
#' @param prefetch.buffer , optional, default=4
#' Backend Param: Number of prefetched parameters
#' @return iter The result mx.dataiter
#'
Expand Down
12 changes: 6 additions & 6 deletions R-package/vignettes/CallbackFunctionTutorial.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Basically, all callback functions follow the structure below:

```{r, eval=FALSE}
mx.callback.fun <- function() {
function(iteration, nbatch, env) {
function(iteration, nbatch, env, verbose) {
}
}
```
Expand All @@ -97,10 +97,10 @@ The `mx.callback.save.checkpoint` function below is stateless. It just get the m

```{r, eval=FALSE}
mx.callback.save.checkpoint <- function(prefix, period=1) {
function(iteration, nbatch, env) {
function(iteration, nbatch, env, verbose) {
if (iteration %% period == 0) {
mx.model.save(env$model, prefix, iteration)
cat(sprintf("Model checkpoint saved to %s-%04d.params\n", prefix, iteration))
if(verbose) cat(sprintf("Model checkpoint saved to %s-%04d.params\n", prefix, iteration))
}
return(TRUE)
}
Expand All @@ -112,11 +112,11 @@ process.

```{r, eval=FALSE}
mx.callback.log.train.metric <- function(period, logger=NULL) {
function(iteration, nbatch, env) {
function(iteration, nbatch, env, verbose) {
if (nbatch %% period == 0 && !is.null(env$metric)) {
result <- env$metric$get(env$train.metric)
if (nbatch != 0)
cat(paste0("Batch [", nbatch, "] Train-", result$name, "=", result$value, "\n"))
if(verbose) cat(paste0("Batch [", nbatch, "] Train-", result$name, "=", result$value, "\n"))
if (!is.null(logger)) {
if (class(logger) != "mx.metric.logger") {
stop("Invalid mx.metric.logger.")
Expand All @@ -142,7 +142,7 @@ Yes! You can stop the training early by `return(FALSE)`. See the examples below.

```{r}
mx.callback.early.stop <- function(eval.metric) {
function(iteration, nbatch, env) {
function(iteration, nbatch, env, verbose) {
if (!is.null(env$metric)) {
if (!is.null(eval.metric)) {
result <- env$metric$get(env$eval.metric)
Expand Down
Loading

0 comments on commit 569c486

Please sign in to comment.