forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a CLI-based fallback ArtifactRepository in Java (mlflow#394)
- Loading branch information
Showing
8 changed files
with
526 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
mlflow/java/client/src/main/java/org/mlflow/artifacts/ArtifactRepository.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
package org.mlflow.artifacts; | ||
|
||
import javax.annotation.Nonnull; | ||
import java.io.File; | ||
import java.util.List; | ||
|
||
import org.mlflow.api.proto.Service.FileInfo; | ||
|
||
/** | ||
* Allows logging, listing, and downloading artifacts against a remote Artifact Repository. | ||
* This is used for storing potentially-large objects associated with MLflow runs. | ||
*/ | ||
public interface ArtifactRepository { | ||
|
||
/** | ||
* Uploads the given local file to the run's root artifact directory. For example, | ||
* | ||
* logArtifact("/my/localModel") | ||
* listArtifacts() // returns "localModel" | ||
* | ||
* @param localFile File to upload. Must exist, and must be a simple file (not a directory). | ||
*/ | ||
void logArtifact(File localFile); | ||
|
||
/** | ||
* Uploads the given local file to an artifactPath within the run's root directory. For example, | ||
* | ||
* logArtifact("/my/localModel", "model") | ||
* listArtifacts("model") // returns "model/localModel" | ||
* | ||
* (i.e., the localModel file is now available in model/localModel). | ||
* | ||
* @param localFile File to upload. Must exist, and must be a simple file (not a directory). | ||
* @param artifactPath Artifact path relative to the run's root directory. Should NOT | ||
* start with a /. | ||
*/ | ||
void logArtifact(File localFile, @Nonnull String artifactPath); | ||
|
||
/** | ||
* Uploads all files within the given local directory to the run's root artifact directory. | ||
* For example, if /my/local/dir/ contains two files "file1" and "file2", then | ||
* | ||
* logArtifacts("/my/local/dir") | ||
* listArtifacts() // returns "file1" and "file2" | ||
* | ||
* @param localDir Directory to upload. Must exist, and must be a directory (not a simple file). | ||
*/ | ||
void logArtifacts(File localDir); | ||
|
||
|
||
/** | ||
* Uploads all files within the given local directory to an artifactPath within the run's root | ||
* artifact directory. For example, if /my/local/dir/ contains two files "file1" and "file2", then | ||
* | ||
* logArtifacts("/my/local/dir", "model") | ||
* listArtifacts("model") // returns "model/file1" and "model/file2" | ||
* | ||
* (i.e., the contents of the local directory are now available in model/). | ||
* | ||
* @param localDir Directory to upload. Must exist, and must be a directory (not a simple file). | ||
* @param artifactPath Artifact path relative to the run's root directory. Should NOT | ||
* start with a /. | ||
*/ | ||
void logArtifacts(File localDir, @Nonnull String artifactPath); | ||
|
||
/** | ||
* Lists the artifacts immediately under the run's root artifact directory. This does not | ||
* recursively list; instead, it will return FileInfos with isDir=true where further | ||
* listing may be done. | ||
* | ||
* @return FileInfos for the entries directly under the run's root artifact directory. | ||
*/ | ||
List<FileInfo> listArtifacts(); | ||
|
||
/** | ||
* Lists the artifacts immediately under the given artifactPath within the run's root artifact | ||
* directory. This does not recursively list; instead, it will return FileInfos with isDir=true | ||
* where further listing may be done. | ||
* | ||
* @param artifactPath Artifact path relative to the run's root directory. Should NOT | ||
* start with a /. | ||
* @return FileInfos for the entries directly under the given artifactPath. | ||
*/ | ||
List<FileInfo> listArtifacts(@Nonnull String artifactPath); | ||
|
||
/** | ||
* Returns a local directory containing *all* artifacts within the run's artifact directory. | ||
* Note that this will download the entire directory path, and so may be expensive if | ||
* the directory contains a lot of data. | ||
* | ||
* @return Local directory holding the downloaded artifacts. | ||
*/ | ||
File downloadArtifacts(); | ||
|
||
/** | ||
* Returns a local file or directory containing all artifacts within the given artifactPath | ||
* within the run's root artifact directory. For example, if "model/file1" and "model/file2" | ||
* exist within the artifact directory, then | ||
* | ||
* downloadArtifacts("model") // returns a local directory containing "file1" and "file2" | ||
* downloadArtifacts("model/file1") // returns a local *file* with the contents of file1. | ||
* | ||
* Note that this will download the entire subdirectory path, and so may be expensive if | ||
* the subdirectory contains a lot of data. | ||
* | ||
* @param artifactPath Artifact path relative to the run's root directory. Should NOT | ||
* start with a /. | ||
* @return Local file or directory holding the downloaded artifact(s). | ||
*/ | ||
File downloadArtifacts(@Nonnull String artifactPath); | ||
} |
245 changes: 245 additions & 0 deletions
245
mlflow/java/client/src/main/java/org/mlflow/artifacts/CliBasedArtifactRepository.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
package org.mlflow.artifacts; | ||
|
||
import java.io.File; | ||
import java.io.IOException; | ||
import java.lang.reflect.Type; | ||
import java.nio.charset.StandardCharsets; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.atomic.AtomicBoolean; | ||
|
||
import com.google.common.collect.Lists; | ||
import com.google.gson.Gson; | ||
import com.google.gson.reflect.TypeToken; | ||
import com.google.protobuf.InvalidProtocolBufferException; | ||
import com.google.protobuf.util.JsonFormat; | ||
import org.apache.commons.io.IOUtils; | ||
import org.apache.log4j.Logger; | ||
|
||
import org.mlflow.api.proto.Service; | ||
import org.mlflow.tracking.MlflowClientException; | ||
import org.mlflow.tracking.creds.MlflowHostCredsProvider; | ||
|
||
/** | ||
* Shells out to the 'mlflow' command line utility to upload, download, and list artifacts. This | ||
* is used as a fallback to implement any artifact repositories which are not natively supported | ||
* within Java. | ||
* | ||
* We require that 'mlflow' is available in the system path. | ||
*/ | ||
public class CliBasedArtifactRepository implements ArtifactRepository { | ||
private static final Logger logger = Logger.getLogger(CliBasedArtifactRepository.class); | ||
|
||
// Global check if we ever successfully loaded 'mlflow'. This allows us to print a more | ||
// helpful error message if the executable is not in the path. | ||
private static final AtomicBoolean mlflowSuccessfullyLoaded = new AtomicBoolean(false); | ||
|
||
// Name of the mlflow CLI utility which can be exec'd directly. | ||
private final String mlflowExecutable = "mlflow"; | ||
|
||
// Base directory of the artifactory, used to let the user know why this repository was chosen. | ||
private final String artifactBaseDir; | ||
|
||
// Run ID this repository is targeting. | ||
private final String runId; | ||
|
||
// Used to pass the MLFLOW_TRACKING_URI on to the mlflow process. | ||
private final MlflowHostCredsProvider hostCredsProvider; | ||
|
||
public CliBasedArtifactRepository( | ||
String artifactBaseDir, | ||
String runId, | ||
MlflowHostCredsProvider hostCredsProvider) { | ||
this.artifactBaseDir = artifactBaseDir; | ||
this.runId = runId; | ||
this.hostCredsProvider = hostCredsProvider; | ||
} | ||
|
||
@Override | ||
public void logArtifact(File localFile, String artifactPath) { | ||
checkMlflowAccessible(); | ||
if (!localFile.exists()) { | ||
throw new MlflowClientException("Local file does not exist: " + localFile); | ||
} | ||
if (localFile.isDirectory()) { | ||
throw new MlflowClientException("Local path points to a directory. Use logArtifacts" + | ||
" instead: " + localFile); | ||
} | ||
|
||
List<String> baseCommand = Lists.newArrayList( | ||
mlflowExecutable, "artifacts", "log-artifact", "--local-file", localFile.toString()); | ||
List<String> command = appendRunIdArtifactPath(baseCommand, runId, artifactPath); | ||
String tag = "log file " + localFile + " to " + getTargetIdentifier(artifactPath); | ||
forkProcess(command, tag); | ||
} | ||
|
||
@Override | ||
public void logArtifact(File localFile) { | ||
logArtifact(localFile, null); | ||
} | ||
|
||
@Override | ||
public void logArtifacts(File localDir, String artifactPath) { | ||
checkMlflowAccessible(); | ||
if (!localDir.exists()) { | ||
throw new MlflowClientException("Local file does not exist: " + localDir); | ||
} | ||
if (localDir.isFile()) { | ||
throw new MlflowClientException("Local path points to a file. Use logArtifact" + | ||
" instead: " + localDir); | ||
} | ||
|
||
List<String> baseCommand = Lists.newArrayList( | ||
mlflowExecutable, "artifacts", "log-artifacts", "--local-dir", localDir.toString()); | ||
List<String> command = appendRunIdArtifactPath(baseCommand, runId, artifactPath); | ||
String tag = "log dir " + localDir + " to " + getTargetIdentifier(artifactPath); | ||
forkProcess(command, tag); | ||
} | ||
|
||
@Override | ||
public void logArtifacts(File localDir) { | ||
logArtifacts(localDir, null); | ||
} | ||
|
||
@Override | ||
public File downloadArtifacts(String artifactPath) { | ||
checkMlflowAccessible(); | ||
String tag = "download artifacts for " + getTargetIdentifier(artifactPath); | ||
List<String> command = appendRunIdArtifactPath( | ||
Lists.newArrayList(mlflowExecutable, "artifacts", "download"), runId, artifactPath); | ||
String localPath = forkProcess(command, tag).trim(); | ||
return new File(localPath); | ||
} | ||
|
||
@Override | ||
public File downloadArtifacts() { | ||
return downloadArtifacts(null); | ||
} | ||
|
||
@Override | ||
public List<Service.FileInfo> listArtifacts(String artifactPath) { | ||
checkMlflowAccessible(); | ||
String tag = "list artifacts in " + getTargetIdentifier(artifactPath); | ||
List<String> command = appendRunIdArtifactPath( | ||
Lists.newArrayList(mlflowExecutable, "artifacts", "list"), runId, artifactPath); | ||
String jsonOutput = forkProcess(command, tag); | ||
return parseFileInfos(jsonOutput); | ||
} | ||
|
||
@Override | ||
public List<Service.FileInfo> listArtifacts() { | ||
return listArtifacts(null); | ||
} | ||
|
||
/** Parses a list of JSON FileInfos, as returned by 'mlflow artifacts list'. */ | ||
private List<Service.FileInfo> parseFileInfos(String json) { | ||
// The protobuf deserializer doesn't allow us to directly deserialize a list, so we | ||
// deserialize a list-of-dictionaries, and then reserialize each dictionary to pass it to | ||
// the protobuf deserializer. | ||
Gson gson = new Gson(); | ||
Type type = new TypeToken<List<Map<String, Object>>>(){}.getType(); | ||
List<Map<String, Object>> listOfDicts = gson.fromJson(json, type); | ||
List<Service.FileInfo> fileInfos = new ArrayList<>(); | ||
for (Map<String, Object> dict: listOfDicts) { | ||
String fileInfoJson = gson.toJson(dict); | ||
try { | ||
Service.FileInfo.Builder builder = Service.FileInfo.newBuilder(); | ||
JsonFormat.parser().merge(fileInfoJson, builder); | ||
fileInfos.add(builder.build()); | ||
} catch (InvalidProtocolBufferException e) { | ||
throw new MlflowClientException("Failed to deserialize JSON into FileInfo: " + json, e); | ||
} | ||
} | ||
return fileInfos; | ||
} | ||
|
||
/** | ||
* Checks whether the 'mlflow' executable is available, and throws a nice error if not. | ||
* If this method has ever run successfully before (in the entire JVM), we will not rerun it. | ||
*/ | ||
private void checkMlflowAccessible() { | ||
if (mlflowSuccessfullyLoaded.get()) { | ||
return; | ||
} | ||
|
||
try { | ||
String tag = "get mlflow version"; | ||
String mlflowVersion = forkProcess(Lists.newArrayList(mlflowExecutable, "--version"), tag); | ||
logger.info("Found local mlflow executable with version=" + mlflowVersion); | ||
mlflowSuccessfullyLoaded.set(true); | ||
} catch (MlflowClientException e) { | ||
String errorMessage = String.format("Failed to exec process %s, needed to access artifacts " + | ||
"within the non-Java-native artifact store at '%s'. Please make sure mlflow is " + | ||
"available on your local system path (e.g.," + "from 'pip install mlflow')", | ||
mlflowExecutable, artifactBaseDir); | ||
throw new MlflowClientException(errorMessage, e); | ||
} | ||
} | ||
|
||
/** | ||
* Forks the given mlflow command and awaits for its successful completion. | ||
* | ||
* @param command Command used to fork the process. | ||
* @param tag User-facing tag which will be used to identify what we were trying to do | ||
* in the case of a failure. | ||
* @return raw stdout of the process, decoded as a utf-8 string | ||
* @throws MlflowClientException if the process exits with a non-zero exit code, or anything | ||
* else goes wrong. | ||
*/ | ||
private String forkProcess(List<String> command, String tag) { | ||
String stdout; | ||
Process process = null; | ||
try { | ||
ProcessBuilder pb = new ProcessBuilder(command); | ||
// TODO(aaron) Figure out a way to pass the other fields of the host-creds. | ||
pb.environment().put("MLFLOW_TRACKING_URI", hostCredsProvider.getHostCreds().getHost()); | ||
process = pb.start(); | ||
stdout = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8); | ||
int exitValue = process.waitFor(); | ||
if (exitValue != 0) { | ||
throw new MlflowClientException("Failed to " + tag + ". Error: " + | ||
getErrorBestEffort(process)); | ||
} | ||
} catch (IOException | InterruptedException e) { | ||
throw new MlflowClientException("Failed to fork mlflow process to " + tag + | ||
". Process stderr: " + getErrorBestEffort(process), e); | ||
} | ||
return stdout; | ||
} | ||
|
||
/** Does our best to get the process's stderr, or returns a dummy return value. */ | ||
private String getErrorBestEffort(Process process) { | ||
if (process == null) { | ||
return "<process not started>"; | ||
} | ||
try { | ||
return IOUtils.toString(process.getErrorStream(), StandardCharsets.UTF_8); | ||
} catch (IOException e) { | ||
return "<error unknown>"; | ||
} | ||
} | ||
|
||
/** Appends --run-id $runId and --artifact-path $artifactPath, omitting artifactPath if null. */ | ||
private List<String> appendRunIdArtifactPath( | ||
List<String> baseCommand, | ||
String runId, | ||
String artifactPath) { | ||
baseCommand.add("--run-id"); | ||
baseCommand.add(runId); | ||
if (artifactPath != null) { | ||
baseCommand.add("--artifact-path"); | ||
baseCommand.add(artifactPath); | ||
} | ||
return baseCommand; | ||
} | ||
|
||
/** Returns user-facing identifier "runId=abc, artifactId=/foo", omitting artifactPath if null. */ | ||
private String getTargetIdentifier(String artifactPath) { | ||
String identifier = "runId=" + runId; | ||
if (artifactPath != null) { | ||
return identifier + ", artifactPath=" + artifactPath; | ||
} | ||
return identifier; | ||
} | ||
} |
Oops, something went wrong.