Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-357] New Scala API Design (NDArray) #10787

Merged
merged 5 commits into from
May 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ object NDArray {

private val functions: Map[String, NDArrayFunction] = initNDArrayModule()

val api = NDArrayAPI

private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = {
froms.foreach { from =>
val weakRef = new WeakReference(from)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mxnet
@AddNDArrayAPIs(false)
/**
* typesafe NDArray API: NDArray.api._
* Main code will be generated during compile time through Macros
*/
object NDArrayAPI {
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,67 +29,134 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot
private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addDefs
}

private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation {
private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs
}

private[mxnet] object NDArrayMacro {
case class NDArrayFunction(handle: NDArrayHandle)
case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])

// scalastyle:off havetype
def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
impl(c)(false, annottees: _*)
impl(c)(annottees: _*)
}
def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
typeSafeAPIImpl(c)(annottees: _*)
}
// scalastyle:off havetype

private val ndarrayFunctions: Map[String, NDArrayFunction] = initNDArrayModule()
private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()

private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = {
private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._

val isContrib: Boolean = c.prefix.tree match {
case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b))
}

val newNDArrayFunctions = {
if (isContrib) ndarrayFunctions.filter(_._1.startsWith("_contrib_"))
else ndarrayFunctions.filter(!_._1.startsWith("_contrib_"))
if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
}

val functionDefs = newNDArrayFunctions flatMap { case (funcName, funcProp) =>
val functionScope = {
if (isContrib) Modifiers()
else {
if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else Modifiers()
val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
val funcName = NDArrayfunction.name
val termName = TermName(funcName)
if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) {
Seq(
// scalastyle:off
// e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
// e.g def transpose(args: Any*)
q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
// scalastyle:on
)
} else {
// Default private
Seq(
// scalastyle:off
q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
// scalastyle:on
)
}
}
val newName = {
if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length())
else funcName
}
val termName = TermName(funcName)
// It will generate definition something like,
Seq(
// scalastyle:off
// def transpose(kwargs: Map[String, Any] = null)(args: Any*)
q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}",
// def transpose(args: Any*)
q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}"
// scalastyle:on
)

structGeneration(c)(functionDefs, annottees : _*)
}

private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
import c.universe._

val isContrib: Boolean = c.prefix.tree match {
case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
}
val newNDArrayFunctions = {
if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
}

val functionDefs = newNDArrayFunctions map { ndarrayfunction =>

// Construct argument field
var argDef = ListBuffer[String]()
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
val currArgName = ndarrayarg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case default => ndarrayarg.argName
}
if (ndarrayarg.isOptional) {
argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
}
else {
argDef += s"${currArgName} : ${ndarrayarg.argType}"
}
var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName
if (ndarrayarg.isOptional) {
base = "if (!" + currArgName + ".isEmpty)" + base + ".get"
}
impl += base
})
// scalastyle:off
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.NDArray"
var finalStr = s"def ${ndarrayfunction.name}New"
finalStr += s" (${argDef.mkString(",")}) : $returnType"
finalStr += s" = {${impl.mkString("\n")}}"
c.parse(finalStr).asInstanceOf[DefDef]
}

structGeneration(c)(functionDefs, annottees : _*)
}

private def structGeneration(c: blackbox.Context)
(funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*)
: c.Expr[Any] = {
import c.universe._
val inputs = annottees.map(_.tree).toList
// pattern match on the inputs
val modDefs = inputs map {
case ClassDef(mods, name, something, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
Template(superMaybe, emptyValDef, defs ++ functionDefs)
Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
ClassDef(mods, name, something, q)
case ModuleDef(mods, name, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
Template(superMaybe, emptyValDef, defs ++ functionDefs)
Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
Expand All @@ -102,20 +169,80 @@ private[mxnet] object NDArrayMacro {
result
}


// Convert C++ Types to Scala Types
private def typeConversion(in : String, argType : String = "") : String = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have these functions shared with those in SymbolMacro?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's change it in Document integration component PR, we can do a global migration on Symbol, NDArray there. Any recommendation where we put these functions? As a single file or inside SymbolMacros.

in match {
case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.NDArray"
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
=> "Array[org.apache.mxnet.NDArray]"
case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat"
case "int" | "intorNone" | "int(non-negative)" => "Int"
case "long" | "long(non-negative)" => "Long"
case "double" | "doubleorNone" => "Double"
case "string" => "String"
case "boolean" | "booleanorNone" => "Boolean"
case "tupleof<float>" | "tupleof<double>" | "ptr" | "" => "Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default, $argType")
}
}


/**
* By default, the argType come from the C++ API is a description more than a single word
* For Example:
* <C++ Type>, <Required/Optional>, <Default=>
* The three field shown above do not usually come at the same time
* This function used the above format to determine if the argument is
* optional, what is it Scala type and possibly pass in a default value
* @param argType Raw arguement Type description
* @return (Scala_Type, isOptional)
*/
private def argumentCleaner(argType : String) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
if (spaceRemoved.charAt(0)== '{') {
val endIdx = spaceRemoved.indexOf('}')
commaRemoved = spaceRemoved.substring(endIdx + 1).split(",")
commaRemoved(0) = "string"
} else {
commaRemoved = spaceRemoved.split(",")
}
// Optional Field
if (commaRemoved.length >= 3) {
// arg: Type, optional, default = Null
require(commaRemoved(1).equals("optional"))
require(commaRemoved(2).startsWith("default="))
(typeConversion(commaRemoved(0), argType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
val tempType = typeConversion(commaRemoved(0), argType)
val tempOptional = tempType.equals("org.apache.mxnet.NDArray")
(tempType, tempOptional)
} else {
throw new IllegalArgumentException(
s"Unrecognized arg field: $argType, ${commaRemoved.length}")
}

}


// List and add all the atomic symbol functions to current module.
private def initNDArrayModule(): Map[String, NDArrayFunction] = {
private def initNDArrayModule(): List[NDArrayFunction] = {
val opNames = ListBuffer.empty[String]
_LIB.mxListAllOpNames(opNames)
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeNDArrayFunction(opHandle.value, opName)
}).toMap
}).toList
}

// Create an atomic symbol function by handle and function name.
private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
: (String, NDArrayFunction) = {
: NDArrayFunction = {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
Expand All @@ -136,10 +263,14 @@ private[mxnet] object NDArrayMacro {
val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
// scalastyle:off println
if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
&& System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
&& System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
println("NDArray function definition:\n" + docStr)
}
// scalastyle:on println
(aliasName, new NDArrayFunction(handle))
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption = argumentCleaner(argType)
new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
}
new NDArrayFunction(aliasName, argList.toList)
}
}