Skip to content

Commit

Permalink
[MXNET-357] New Scala API Design (Symbol) (apache#10660)
Browse files Browse the repository at this point in the history
* Simplfied current Macros impl to Quasiquote
* Change the Symbol Function Field, add SymbolArg
* Fix the Macros problem, disable the hidden function _
* Add Implementation for New API
* Add examples and comments
* Add _contrib_ support
* New namespace for Symbol API
* Change names and add comments
* add TODOs and name changes
* Add relative path to MXNET_BASEDIR
* Update Base.scala
  • Loading branch information
lanking520 authored and nswamy committed May 14, 2018
1 parent 8b53a3d commit b011ecc
Show file tree
Hide file tree
Showing 8 changed files with 340 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,8 @@ object Symbol {
private val functions: Map[String, SymbolFunction] = initSymbolModule()
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)

val api = SymbolAPI

def pow(sym1: Symbol, sym2: Symbol): Symbol = {
Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
}
Expand Down
26 changes: 26 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mxnet


@AddSymbolAPIs(false)
/**
* typesafe Symbol API: Symbol.api._
* Main code will be generated during compile time through Macros
*/
object SymbolAPI {
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,40 @@ object TrainMnist {
// multi-layer perceptron
def getMlp: Symbol = {
val data = Symbol.Variable("data")
val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128))
val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64))
val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2, "num_hidden" -> 10))
val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3))

val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1")
val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu")
val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name = "fc2")
val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2")
val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name = "fc3")
val mlp = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc3))
mlp
}

// LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
// Haffner. "Gradient-based learning applied to document recognition."
// Proceedings of the IEEE (1998)

def getLenet: Symbol = {
val data = Symbol.Variable("data")
// first conv
val conv1 = Symbol.Convolution()()(
Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20))
val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh"))
val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max",
"kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5), num_filter = 20)
val tanh1 = Symbol.api.tanh(data = Some(conv1))
val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
// second conv
val conv2 = Symbol.Convolution()()(
Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50))
val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh"))
val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
"kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 5), num_filter = 50)
val tanh2 = Symbol.api.tanh(data = Some(conv2))
val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
// first fullc
val flatten = Symbol.Flatten()()(Map("data" -> pool2))
val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" -> 500))
val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh"))
val flatten = Symbol.api.Flatten(data = Some(pool2))
val fc1 = Symbol.api.FullyConnected(data = Some(flatten), num_hidden = 500)
val tanh3 = Symbol.api.tanh(data = Some(fc1))
// second fullc
val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" -> 10))
val fc2 = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 10)
// loss
val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2))
val lenet = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc2))
lenet
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ object Base {

@throws(classOf[UnsatisfiedLinkError])
private def tryLoadInitLibrary(): Unit = {
val baseDir = System.getProperty("user.dir") + "/init-native"
var baseDir = System.getProperty("user.dir") + "/init-native"
// TODO(lanKing520) Update this to use relative path to the MXNet director.
// TODO(lanking520) baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native"
if (System.getenv().containsKey("MXNET_BASEDIR")) {
baseDir = sys.env("MXNET_BASEDIR")
}
val os = System.getProperty("os.name")
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
Expand Down
38 changes: 38 additions & 0 deletions scala-package/macros/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,42 @@
<type>${libtype}</type>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<excludes>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
<configuration>
<environmentVariables>
<MXNET_BASEDIR>${project.parent.basedir}/init-native</MXNET_BASEDIR>
</environmentVariables>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target \
-Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
</argLine>
</configuration>
</plugin>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
</plugins>
</build>

</project>
Loading

0 comments on commit b011ecc

Please sign in to comment.