Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add arg.params,aux.params for mx.model.FeedForward.create function [#…
Browse files Browse the repository at this point in the history
…1543]

modify basic_model.R to test to load model and continue training [#1543]
  • Loading branch information
ziyeqinghan committed Apr 3, 2016
1 parent e04ba1d commit 677f067
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
8 changes: 8 additions & 0 deletions R-package/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,14 @@ mx.model.select.layout.predict <- function(X, model) {
#' The parameter synchronization scheme in multiple devices.
#' @param verbose logical (default=TRUE)
#' Specifies whether to print information on the iterations during training.
#' @param arg.params list, optional
#' Model parameter, list of name to NDArray of net's weights.
#' @param aux.params list, optional
#' Model parameter, list of name to NDArray of net's auxiliary states.
#' @return model A trained mxnet model.
#'
#' @export

mx.model.FeedForward.create <-
function(symbol, X, y=NULL, ctx=NULL,
num.round=10, optimizer="sgd",
Expand All @@ -390,6 +395,7 @@ function(symbol, X, y=NULL, ctx=NULL,
array.batch.size=128, array.layout="auto",
kvstore="local",
verbose=TRUE,
arg.params=NULL, aux.params=NULL,
...) {
if (is.array(X) || is.matrix(X)) {
if (array.layout == "auto") {
Expand All @@ -406,6 +412,8 @@ function(symbol, X, y=NULL, ctx=NULL,
}
input.shape <- dim((X$value())$data)
params <- mx.model.init.params(symbol, input.shape, initializer, mx.cpu())
if (!is.null(arg.params)) params$arg.params <- arg.params
if (!is.null(aux.params)) params$aux.params <- aux.params
if (is.null(ctx)) ctx <- mx.ctx.default()
if (is.mx.context(ctx)) {
ctx <- list(ctx)
Expand Down
31 changes: 31 additions & 0 deletions R-package/demo/basic_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,34 @@ accuracy <- function(label, pred) {
print(paste0("Finish prediction... accuracy=", accuracy(label, pred)))
print(paste0("Finish prediction... accuracy2=", accuracy(label, pred2)))



# load the model
model <- mx.model.load("chkpt", 1)

#continue training with some new arguments
model <- mx.model.FeedForward.create(model$symbol, X=dtrain, eval.data=dtest,
ctx=devices, num.round=5,
learning.rate=0.1, momentum=0.9,
epoch.end.callback=mx.callback.save.checkpoint("reload_chkpt"),
batch.end.callback=mx.callback.log.train.metric(100),
arg.params=model$arg.params, aux.params=model$aux.params)

# do prediction
pred <- predict(model, dtest)
label <- mx.io.extract(dtest, "label")
dataX <- mx.io.extract(dtest, "data")
# Predict with R's array
pred2 <- predict(model, X=dataX)

accuracy <- function(label, pred) {
ypred = max.col(t(as.array(pred)))
return(sum((as.array(label) + 1) == ypred) / length(label))
}

print(paste0("Finish prediction... accuracy=", accuracy(label, pred)))
print(paste0("Finish prediction... accuracy2=", accuracy(label, pred2)))




9 changes: 8 additions & 1 deletion R-package/man/mx.model.FeedForward.create.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 677f067

Please sign in to comment.