Skip to content

Commit

Permalink
Add a CLI-based fallback ArtifactRepository in Java (mlflow#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarondav committed Aug 29, 2018
1 parent 89efc29 commit b79c870
Show file tree
Hide file tree
Showing 8 changed files with 526 additions and 8 deletions.
14 changes: 14 additions & 0 deletions mlflow/java/client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
<version>1.3.9</version>
</dependency>
<dependency>
<groupId>javax.annotation</groupId>
<artifactId>javax.annotation-api</artifactId>
</dependency>
</dependencies>

<build>
Expand Down Expand Up @@ -72,6 +85,7 @@
<include>com.google.guava:guava</include>
<include>com.google.code.gson:gson</include>
<include>commons-codec:commons-codec</include>
<include>commons-io:commons-io</include>
<include>commons-logging:commons-logging</include>
<include>org.apache.httpcomponents:httpcore</include>
</includes>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package org.mlflow.artifacts;

import javax.annotation.Nonnull;
import java.io.File;
import java.util.List;

import org.mlflow.api.proto.Service.FileInfo;

/**
* Allows logging, listing, and downloading artifacts against a remote Artifact Repository.
* This is used for storing potentially-large objects associated with MLflow runs.
*
* All artifactPath arguments are relative to the run's root artifact directory and should NOT
* start with a /.
*/
public interface ArtifactRepository {

/**
* Uploads the given local file to the run's root artifact directory. For example,
*
* logArtifact("/my/localModel")
* listArtifacts() // returns "localModel"
*
* @param localFile File to upload. Must exist, and must be a simple file (not a directory).
*/
void logArtifact(File localFile);

/**
* Uploads the given local file to an artifactPath within the run's root directory. For example,
*
* logArtifact("/my/localModel", "model")
* listArtifacts("model") // returns "model/localModel"
*
* (i.e., the localModel file is now available in model/localModel).
*
* @param localFile File to upload. Must exist, and must be a simple file (not a directory).
* @param artifactPath Artifact path relative to the run's root directory. Should NOT
* start with a /.
*/
void logArtifact(File localFile, @Nonnull String artifactPath);

/**
* Uploads all files within the given local directory to the run's root artifact directory.
* For example, if /my/local/dir/ contains two files "file1" and "file2", then
*
* logArtifacts("/my/local/dir")
* listArtifacts() // returns "file1" and "file2"
*
* @param localDir Directory to upload. Must exist, and must be a directory (not a simple file).
*/
void logArtifacts(File localDir);


/**
* Uploads all files within the given local directory to an artifactPath within the run's root
* artifact directory. For example, if /my/local/dir/ contains two files "file1" and "file2", then
*
* logArtifacts("/my/local/dir", "model")
* listArtifacts("model") // returns "model/file1" and "model/file2"
*
* (i.e., the contents of the local directory are now available in model/).
*
* @param localDir Directory to upload. Must exist, and must be a directory (not a simple file).
* @param artifactPath Artifact path relative to the run's root directory. Should NOT
* start with a /.
*/
void logArtifacts(File localDir, @Nonnull String artifactPath);

/**
* Lists the artifacts immediately under the run's root artifact directory. This does not
* recursively list; instead, it will return FileInfos with isDir=true where further
* listing may be done.
*
* @return One FileInfo per entry directly under the run's root artifact directory.
*/
List<FileInfo> listArtifacts();

/**
* Lists the artifacts immediately under the given artifactPath within the run's root artifact
* directory. This does not recursively list; instead, it will return FileInfos with isDir=true
* where further listing may be done.
*
* @param artifactPath Artifact path relative to the run's root directory. Should NOT
* start with a /.
* @return One FileInfo per entry directly under the given artifactPath.
*/
List<FileInfo> listArtifacts(@Nonnull String artifactPath);

/**
* Returns a local directory containing *all* artifacts within the run's artifact directory.
* Note that this will download the entire directory path, and so may be expensive if
* the directory contains a lot of data.
*
* @return Local directory holding the downloaded artifacts.
*/
File downloadArtifacts();

/**
* Returns a local file or directory containing all artifacts within the given artifactPath
* within the run's root artifact directory. For example, if "model/file1" and "model/file2"
* exist within the artifact directory, then
*
* downloadArtifacts("model") // returns a local directory containing "file1" and "file2"
* downloadArtifacts("model/file1") // returns a local *file* with the contents of file1.
*
* Note that this will download the entire subdirectory path, and so may be expensive if
* the subdirectory contains a lot of data.
*
* @param artifactPath Artifact path relative to the run's root directory. Should NOT
* start with a /.
* @return Local file (if artifactPath names a file) or local directory (if it names a
* directory) holding the downloaded contents.
*/
File downloadArtifacts(@Nonnull String artifactPath);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
package org.mlflow.artifacts;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;

import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils;
import org.apache.log4j.Logger;

import org.mlflow.api.proto.Service;
import org.mlflow.tracking.MlflowClientException;
import org.mlflow.tracking.creds.MlflowHostCredsProvider;

/**
 * Shells out to the 'mlflow' command line utility to upload, download, and list artifacts. This
 * is used as a fallback to implement any artifact repositories which are not natively supported
 * within Java.
 *
 * We require that 'mlflow' is available in the system path.
 *
 * Every operation forks one child process and blocks until it exits. The no-argument overloads
 * delegate to the artifactPath overloads with a null path, which means "the run's root artifact
 * directory" (no --artifact-path flag is passed to the CLI).
 */
public class CliBasedArtifactRepository implements ArtifactRepository {
  private static final Logger logger = Logger.getLogger(CliBasedArtifactRepository.class);

  // Global check if we ever successfully loaded 'mlflow'. This allows us to print a more
  // helpful error message if the executable is not in the path.
  private static final AtomicBoolean mlflowSuccessfullyLoaded = new AtomicBoolean(false);

  // Name of the mlflow CLI utility which can be exec'd directly.
  private final String mlflowExecutable = "mlflow";

  // Base directory of the artifactory, used to let the user know why this repository was chosen.
  private final String artifactBaseDir;

  // Run ID this repository is targeting.
  private final String runId;

  // Used to pass the MLFLOW_TRACKING_URI on to the mlflow process.
  private final MlflowHostCredsProvider hostCredsProvider;

  /**
   * @param artifactBaseDir Base artifact location; only used in error messages.
   * @param runId Run whose artifacts this repository reads and writes.
   * @param hostCredsProvider Source of the tracking host handed to the forked mlflow process.
   */
  public CliBasedArtifactRepository(
      String artifactBaseDir,
      String runId,
      MlflowHostCredsProvider hostCredsProvider) {
    this.artifactBaseDir = artifactBaseDir;
    this.runId = runId;
    this.hostCredsProvider = hostCredsProvider;
  }

  /**
   * Uploads a single local file via 'mlflow artifacts log-artifact'.
   *
   * @throws MlflowClientException if localFile does not exist, is a directory, or the CLI fails.
   */
  @Override
  public void logArtifact(File localFile, String artifactPath) {
    checkMlflowAccessible();
    if (!localFile.exists()) {
      throw new MlflowClientException("Local file does not exist: " + localFile);
    }
    if (localFile.isDirectory()) {
      throw new MlflowClientException("Local path points to a directory. Use logArtifacts" +
        " instead: " + localFile);
    }

    List<String> baseCommand = Lists.newArrayList(
      mlflowExecutable, "artifacts", "log-artifact", "--local-file", localFile.toString());
    List<String> command = appendRunIdArtifactPath(baseCommand, runId, artifactPath);
    String tag = "log file " + localFile + " to " + getTargetIdentifier(artifactPath);
    forkProcess(command, tag);
  }

  @Override
  public void logArtifact(File localFile) {
    logArtifact(localFile, null);
  }

  /**
   * Uploads the contents of a local directory via 'mlflow artifacts log-artifacts'.
   *
   * @throws MlflowClientException if localDir does not exist, is a plain file, or the CLI fails.
   */
  @Override
  public void logArtifacts(File localDir, String artifactPath) {
    checkMlflowAccessible();
    if (!localDir.exists()) {
      throw new MlflowClientException("Local file does not exist: " + localDir);
    }
    if (localDir.isFile()) {
      throw new MlflowClientException("Local path points to a file. Use logArtifact" +
        " instead: " + localDir);
    }

    List<String> baseCommand = Lists.newArrayList(
      mlflowExecutable, "artifacts", "log-artifacts", "--local-dir", localDir.toString());
    List<String> command = appendRunIdArtifactPath(baseCommand, runId, artifactPath);
    String tag = "log dir " + localDir + " to " + getTargetIdentifier(artifactPath);
    forkProcess(command, tag);
  }

  @Override
  public void logArtifacts(File localDir) {
    logArtifacts(localDir, null);
  }

  /**
   * Runs 'mlflow artifacts download' and interprets its (trimmed) stdout as the local path
   * the artifacts were downloaded to.
   */
  @Override
  public File downloadArtifacts(String artifactPath) {
    checkMlflowAccessible();
    String tag = "download artifacts for " + getTargetIdentifier(artifactPath);
    List<String> command = appendRunIdArtifactPath(
      Lists.newArrayList(mlflowExecutable, "artifacts", "download"), runId, artifactPath);
    String localPath = forkProcess(command, tag).trim();
    return new File(localPath);
  }

  @Override
  public File downloadArtifacts() {
    return downloadArtifacts(null);
  }

  /** Runs 'mlflow artifacts list' and parses its stdout as a JSON list of FileInfos. */
  @Override
  public List<Service.FileInfo> listArtifacts(String artifactPath) {
    checkMlflowAccessible();
    String tag = "list artifacts in " + getTargetIdentifier(artifactPath);
    List<String> command = appendRunIdArtifactPath(
      Lists.newArrayList(mlflowExecutable, "artifacts", "list"), runId, artifactPath);
    String jsonOutput = forkProcess(command, tag);
    return parseFileInfos(jsonOutput);
  }

  @Override
  public List<Service.FileInfo> listArtifacts() {
    return listArtifacts(null);
  }

  /**
   * Parses a list of JSON FileInfos, as returned by 'mlflow artifacts list'.
   *
   * @return the parsed FileInfos; an empty list if the CLI produced no JSON at all.
   * @throws MlflowClientException if an element cannot be converted into a FileInfo.
   */
  private List<Service.FileInfo> parseFileInfos(String json) {
    // The protobuf deserializer doesn't allow us to directly deserialize a list, so we
    // deserialize a list-of-dictionaries, and then reserialize each dictionary to pass it to
    // the protobuf deserializer.
    Gson gson = new Gson();
    Type type = new TypeToken<List<Map<String, Object>>>(){}.getType();
    List<Map<String, Object>> listOfDicts = gson.fromJson(json, type);
    List<Service.FileInfo> fileInfos = new ArrayList<>();
    if (listOfDicts == null) {
      // Gson returns null (rather than an empty list) for empty input; treat that as
      // "nothing listed" instead of failing with a NullPointerException below.
      return fileInfos;
    }
    for (Map<String, Object> dict : listOfDicts) {
      String fileInfoJson = gson.toJson(dict);
      try {
        Service.FileInfo.Builder builder = Service.FileInfo.newBuilder();
        JsonFormat.parser().merge(fileInfoJson, builder);
        fileInfos.add(builder.build());
      } catch (InvalidProtocolBufferException e) {
        throw new MlflowClientException("Failed to deserialize JSON into FileInfo: " + json, e);
      }
    }
    return fileInfos;
  }

  /**
   * Checks whether the 'mlflow' executable is available, and throws a nice error if not.
   * If this method has ever run successfully before (in the entire JVM), we will not rerun it.
   *
   * @throws MlflowClientException if 'mlflow --version' cannot be executed successfully.
   */
  private void checkMlflowAccessible() {
    if (mlflowSuccessfullyLoaded.get()) {
      return;
    }

    try {
      String tag = "get mlflow version";
      // The CLI terminates its output with a newline; strip it so the log line stays on one line.
      String mlflowVersion =
        forkProcess(Lists.newArrayList(mlflowExecutable, "--version"), tag).trim();
      logger.info("Found local mlflow executable with version=" + mlflowVersion);
      mlflowSuccessfullyLoaded.set(true);
    } catch (MlflowClientException e) {
      String errorMessage = String.format("Failed to exec process %s, needed to access artifacts " +
        "within the non-Java-native artifact store at '%s'. Please make sure mlflow is " +
        "available on your local system path (e.g., from 'pip install mlflow')",
        mlflowExecutable, artifactBaseDir);
      throw new MlflowClientException(errorMessage, e);
    }
  }

  /**
   * Forks the given mlflow command and awaits for its successful completion.
   *
   * The child's stderr is redirected to a temporary file instead of a pipe. Reading stdout to
   * EOF while stderr sits unread in a pipe can block forever once the child fills the stderr
   * pipe buffer; a file never applies that back-pressure.
   *
   * @param command Command used to fork the process.
   * @param tag User-facing tag which will be used to identify what we were trying to do
   *            in the case of a failure.
   * @return raw stdout of the process, decoded as a utf-8 string
   * @throws MlflowClientException if the process exits with a non-zero exit code, or anything
   *                               else goes wrong. If the calling thread is interrupted while
   *                               waiting, its interrupt flag is restored before throwing.
   */
  private String forkProcess(List<String> command, String tag) {
    File stderrFile = null;
    Process process = null;
    try {
      stderrFile = File.createTempFile("mlflow-cli-stderr-", ".log");
      ProcessBuilder pb = new ProcessBuilder(command);
      // TODO(aaron) Figure out a way to pass the other fields of the host-creds.
      pb.environment().put("MLFLOW_TRACKING_URI", hostCredsProvider.getHostCreds().getHost());
      pb.redirectError(stderrFile);
      process = pb.start();
      String stdout = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
      int exitValue = process.waitFor();
      if (exitValue != 0) {
        throw new MlflowClientException("Failed to " + tag + ". Error: " +
          getErrorBestEffort(process, stderrFile));
      }
      return stdout;
    } catch (IOException e) {
      throw new MlflowClientException("Failed to fork mlflow process to " + tag +
        ". Process stderr: " + getErrorBestEffort(process, stderrFile), e);
    } catch (InterruptedException e) {
      // Read stderr BEFORE restoring the interrupt flag: file channel reads abort immediately
      // on a thread whose interrupt flag is set.
      String stderr = getErrorBestEffort(process, stderrFile);
      Thread.currentThread().interrupt();
      throw new MlflowClientException("Failed to fork mlflow process to " + tag +
        ". Process stderr: " + stderr, e);
    } finally {
      if (process != null) {
        // Releases the process's stream handles; also stops the child if we are bailing out
        // (e.g. on interrupt) while it is still running. No-op effect on an exited process.
        process.destroy();
      }
      if (stderrFile != null && !stderrFile.delete()) {
        stderrFile.deleteOnExit();
      }
    }
  }

  /**
   * Does our best to get the process's stderr, or returns a dummy return value.
   *
   * @param process Forked process, or null if it was never started.
   * @param stderrFile File the process's stderr was redirected to; only read if process is
   *                   non-null.
   */
  private String getErrorBestEffort(Process process, File stderrFile) {
    if (process == null) {
      return "<process not started>";
    }
    try {
      return new String(Files.readAllBytes(stderrFile.toPath()), StandardCharsets.UTF_8);
    } catch (IOException e) {
      return "<error unknown>";
    }
  }

  /**
   * Appends --run-id $runId and --artifact-path $artifactPath, omitting artifactPath if null.
   * The given list is modified in place and returned.
   */
  private List<String> appendRunIdArtifactPath(
      List<String> baseCommand,
      String runId,
      String artifactPath) {
    baseCommand.add("--run-id");
    baseCommand.add(runId);
    if (artifactPath != null) {
      baseCommand.add("--artifact-path");
      baseCommand.add(artifactPath);
    }
    return baseCommand;
  }

  /**
   * Returns user-facing identifier "runId=abc, artifactPath=/foo", omitting artifactPath if null.
   */
  private String getTargetIdentifier(String artifactPath) {
    String identifier = "runId=" + runId;
    if (artifactPath != null) {
      return identifier + ", artifactPath=" + artifactPath;
    }
    return identifier;
  }
}
Loading

0 comments on commit b79c870

Please sign in to comment.