[SYSTEMML-508] Extend "executeScript" In MLContext To Accept PyDML.
This adds the ability to run PyDML via `executeScript` from both Python and Scala.

NOTE:  This is simply a quick addition, as `MLContext` is receiving a major overhaul that will overwrite all of this.

Closes apache#139.
dusenberrymw committed May 9, 2016
1 parent 5bde577 commit a1f1cf5
Showing 2 changed files with 38 additions and 7 deletions.
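
As the commit message notes, `executeScript` can now run PyDML. Below is a minimal Python-side sketch of the intended usage; the `SystemML` import path, the `MLContext(sc)` setup, and the toy script are illustrative assumptions and not part of this commit.

```python
# Hedged sketch: import path, MLContext(sc) constructor, and the toy script are assumptions.
from SystemML import MLContext  # wrapper defined in src/main/java/org/apache/sysml/api/python/SystemML.py

ml = MLContext(sc)  # sc is an existing SparkContext from the PySpark session

# A trivial PyDML (Python-like syntax) script, used only for illustration.
pydml_script = """
x = 1.0 + 2.0
print(x)
"""

# New in this commit: the isPyDML flag selects the PyDML parser instead of DML.
out = ml.executeScript(pydml_script, isPyDML=True)
```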
41 changes: 36 additions & 5 deletions src/main/java/org/apache/sysml/api/MLContext.java
@@ -1118,7 +1118,8 @@ private void initializeSparkListener(SparkContext sc) throws DMLRuntimeException
}

/**
* Experimental API. Not supported in Python MLContext API.
* Execute a script stored in a string.
*
* @param dmlScript
* @return
* @throws IOException
@@ -1127,31 +1128,61 @@ private void initializeSparkListener(SparkContext sc) throws DMLRuntimeException
*/
public MLOutput executeScript(String dmlScript)
throws IOException, DMLException {
return compileAndExecuteScript(dmlScript, null, false, false, false, null);
return executeScript(dmlScript, false);
}


public MLOutput executeScript(String dmlScript, boolean isPyDML)
throws IOException, DMLException {
return executeScript(dmlScript, isPyDML, null);
}

public MLOutput executeScript(String dmlScript, String configFilePath)
throws IOException, DMLException {
return compileAndExecuteScript(dmlScript, null, false, false, false, configFilePath);
return executeScript(dmlScript, false, configFilePath);
}

public MLOutput executeScript(String dmlScript, boolean isPyDML, String configFilePath)
throws IOException, DMLException {
return compileAndExecuteScript(dmlScript, null, false, false, isPyDML, configFilePath);
}

public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs)
throws IOException, DMLException {
return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), null);
}

public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs, boolean isPyDML)
throws IOException, DMLException {
return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), isPyDML, null);
}

public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs, String configFilePath)
throws IOException, DMLException {
return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), configFilePath);
}

public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs, boolean isPyDML, String configFilePath)
throws IOException, DMLException {
return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), isPyDML, configFilePath);
}

public MLOutput executeScript(String dmlScript, Map<String, String> namedArgs)
throws IOException, DMLException {
return executeScript(dmlScript, namedArgs, null);
}

public MLOutput executeScript(String dmlScript, Map<String, String> namedArgs, boolean isPyDML)
throws IOException, DMLException {
return executeScript(dmlScript, namedArgs, isPyDML, null);
}

public MLOutput executeScript(String dmlScript, Map<String, String> namedArgs, String configFilePath)
throws IOException, DMLException {
return executeScript(dmlScript, namedArgs, false, configFilePath);
}

public MLOutput executeScript(String dmlScript, Map<String, String> namedArgs, boolean isPyDML, String configFilePath)
throws IOException, DMLException {
String [] args = new String[namedArgs.size()];
int i = 0;
for(Entry<String, String> entry : namedArgs.entrySet()) {
@@ -1161,7 +1192,7 @@ public MLOutput executeScript(String dmlScript, Map<String, String> namedArgs, S
args[i] = entry.getKey() + "=" + entry.getValue();
i++;
}
return compileAndExecuteScript(dmlScript, args, false, true, false, configFilePath);
return compileAndExecuteScript(dmlScript, args, false, true, isPyDML, configFilePath);
}

private void checkIfRegisteringInputAllowed() throws DMLRuntimeException {
4 changes: 2 additions & 2 deletions src/main/java/org/apache/sysml/api/python/SystemML.py
@@ -100,7 +100,7 @@ def execute(self, dmlScriptFilePath, *args):
except Py4JJavaError:
traceback.print_exc()

def executeScript(self, dmlScript, nargs=None, outputs=None, configFilePath=None):
def executeScript(self, dmlScript, nargs=None, outputs=None, isPyDML=False, configFilePath=None):
"""
Executes the script in spark-mode by passing the arguments to the
MLContext java class.
@@ -125,7 +125,7 @@ def executeScript(self, dmlScript, nargs=None, outputs=None, configFilePath=None
self.registerOutput(out)

# Execute script
jml_out = self.ml.executeScript(dmlScript, nargs, configFilePath)
jml_out = self.ml.executeScript(dmlScript, nargs, isPyDML, configFilePath)
ml_out = MLOutput(jml_out, self.sc)
return ml_out
except Py4JJavaError:
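
For completeness, a sketch of the updated Python signature used together with registered outputs, reusing the assumed `MLContext(sc)` setup from the earlier example; the PyDML script (including the `full()` constructor) and the variable name are illustrative assumptions.

```python
# Hedged sketch reusing the assumed setup from the earlier example.
pydml_script = """
result = full(1.0, rows=3, cols=3)
"""

# 'outputs' lists script variables to register as outputs (via registerOutput) before execution;
# 'isPyDML' is the flag added by this commit, forwarded to the Java executeScript overload.
out = ml.executeScript(pydml_script, outputs=["result"], isPyDML=True)
```

Scala callers reach the same functionality directly through the new `executeScript(dmlScript, isPyDML)` and named-argument overloads shown in the MLContext.java diff above.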
