Skip to content

Commit

Permalink
Expose ArtifactRepository from MlflowClient (mlflow#438)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarondav committed Sep 6, 2018
1 parent b722281 commit 5da2afc
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 21 deletions.
2 changes: 1 addition & 1 deletion mlflow/java/client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
<artifactId>maven-javadoc-plugin</artifactId>
<configuration>
<sourcepath>${project.basedir}/src/main/java</sourcepath>
<excludePackageNames>com.databricks.api.proto.databricks:org.mlflow.scalapb_interface:org.mlflow.tracking.samples</excludePackageNames>
<excludePackageNames>com.databricks.api.proto.databricks:org.mlflow.scalapb_interface:org.mlflow.tracking.samples:org.mlflow.artifacts</excludePackageNames>
<groups>
<group>
<title>Tracking API</title>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ public interface ArtifactRepository {
/**
* Uploads the given local file to the run's root artifact directory. For example,
*
* logArtifact("/my/localModel")
* listArtifacts() // returns "localModel"
* <pre>
* logArtifact("/my/localModel")
* listArtifacts() // returns "localModel"
* </pre>
*
* @param localFile File to upload. Must exist, and must be a simple file (not a directory).
*/
Expand All @@ -24,8 +26,10 @@ public interface ArtifactRepository {
/**
* Uploads the given local file to an artifactPath within the run's root directory. For example,
*
* <pre>
* logArtifact("/my/localModel", "model")
* listArtifacts("model") // returns "model/localModel"
* </pre>
*
* (i.e., the localModel file is now available in model/localModel).
*
Expand All @@ -39,8 +43,10 @@ public interface ArtifactRepository {
* Uploads all files within the given local directory to the run's root artifact directory.
* For example, if /my/local/dir/ contains two files "file1" and "file2", then
*
* logArtifacts("/my/local/dir")
* listArtifacts() // returns "file1" and "file2"
* <pre>
* logArtifacts("/my/local/dir")
* listArtifacts() // returns "file1" and "file2"
* </pre>
*
* @param localDir Directory to upload. Must exist, and must be a directory (not a simple file).
*/
Expand All @@ -51,8 +57,10 @@ public interface ArtifactRepository {
* Uploads all files within the given local directory to an artifactPath within the run's root
* artifact directory. For example, if /my/local/dir/ contains two files "file1" and "file2", then
*
* logArtifacts("/my/local/dir", "model")
* listArtifacts("model") // returns "model/file1" and "model/file2"
* <pre>
* logArtifacts("/my/local/dir", "model")
* listArtifacts("model") // returns "model/file1" and "model/file2"
* </pre>
*
* (i.e., the contents of the local directory are now available in model/).
*
Expand Down Expand Up @@ -90,8 +98,10 @@ public interface ArtifactRepository {
* within the run's root artifactDirectory. For example, if "model/file1" and "model/file2"
* exist within the artifact directory, then
*
* <pre>
* downloadArtifacts("model") // returns a local directory containing "file1" and "file2"
* downloadArtifacts("model/file1") // returns a local *file* with the contents of file1.
* </pre>
*
* Note that this will download the entire subdirectory path, and so may be expensive if
* the subdirectory has a lot of data.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.mlflow.artifacts;

import java.net.URI;

import org.mlflow.tracking.creds.MlflowHostCredsProvider;

/**
 * Creates {@link ArtifactRepository} instances that all share a single
 * {@link MlflowHostCredsProvider}.
 */
public class ArtifactRepositoryFactory {
  /** Credentials provider passed to every repository this factory creates. */
  private final MlflowHostCredsProvider hostCredsProvider;

  /**
   * @param hostCredsProvider Credentials provider handed to each created repository.
   */
  public ArtifactRepositoryFactory(MlflowHostCredsProvider hostCredsProvider) {
    this.hostCredsProvider = hostCredsProvider;
  }

  /**
   * Builds an artifact repository for one run.
   *
   * @param baseArtifactUri Base artifact location of the run; it is passed on in its string form.
   * @param runId Run ID the repository operates on.
   * @return A {@link CliBasedArtifactRepository} for the given location and run.
   */
  public ArtifactRepository getArtifactRepository(URI baseArtifactUri, String runId) {
    String artifactLocation = baseArtifactUri.toString();
    return new CliBasedArtifactRepository(artifactLocation, runId, hostCredsProvider);
  }
}

This file was deleted.

151 changes: 140 additions & 11 deletions mlflow/java/client/src/main/java/org/mlflow/tracking/MlflowClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import org.apache.http.client.utils.URIBuilder;

import org.mlflow.api.proto.Service.*;
import org.mlflow.artifacts.ArtifactRepository;
import org.mlflow.artifacts.ArtifactRepositoryFactory;
import org.mlflow.tracking.creds.*;

import java.io.File;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
Expand All @@ -18,6 +21,7 @@ public class MlflowClient {
private static final long DEFAULT_EXPERIMENT_ID = 0;

private final MlflowProtobufMapper mapper = new MlflowProtobufMapper();
private final ArtifactRepositoryFactory artifactRepositoryFactory;
private final MlflowHttpCaller httpCaller;
private final MlflowHostCredsProvider hostCredsProvider;

Expand All @@ -38,6 +42,7 @@ public MlflowClient(String trackingUri) {
public MlflowClient(MlflowHostCredsProvider hostCredsProvider) {
  // The HTTP caller and the artifact repository factory are both backed by the same provider.
  this.hostCredsProvider = hostCredsProvider;
  this.artifactRepositoryFactory = new ArtifactRepositoryFactory(hostCredsProvider);
  this.httpCaller = new MlflowHttpCaller(hostCredsProvider);
}

/** @return run associated with the id. */
Expand Down Expand Up @@ -160,14 +165,6 @@ public void setTerminated(String runUuid, RunStatus status, long endTime) {
sendPost("runs/update", mapper.makeUpdateRun(runUuid, status, endTime));
}

/**
 * Queries the "artifacts/list" endpoint for a run.
 *
 * @param runUuid Sent as the "run_uuid" query parameter.
 * @param path Sent as the "path" query parameter.
 * @return a list of all artifacts under the given artifact path within the run.
 */
public ListArtifacts.Response listArtifacts(String runUuid, String path) {
  URIBuilder listRequest = newURIBuilder("artifacts/list")
      .setParameter("run_uuid", runUuid)
      .setParameter("path", path);
  String requestPath = listRequest.toString();
  return mapper.toListArtifactsResponse(httpCaller.get(requestPath));
}

/**
* Send a GET to the following path, including query parameters.
* This is mostly an internal API, but allows making lower-level or unsupported requests.
Expand All @@ -187,10 +184,9 @@ public String sendPost(String path, String json) {
}

/**
* Intended for internal usage, and may be removed in future versions.
* @return HostCredsProvider backing this MlflowClient.
* @return HostCredsProvider backing this MlflowClient. Visible for testing.
*/
public MlflowHostCredsProvider getInternalHostCredsProvider() {
MlflowHostCredsProvider getInternalHostCredsProvider() {
return hostCredsProvider;
}

Expand Down Expand Up @@ -246,4 +242,137 @@ private static MlflowHostCredsProvider getHostCredsProviderFromTrackingUri(Strin
}
return provider;
}

/**
 * Uploads a single local file into the root of the run's artifact directory. For example,
 *
 *   <pre>
 *   logArtifact(runId, new File("/my/localModel"))
 *   listArtifacts(runId) // contains "localModel"
 *   </pre>
 *
 * @param runId Run ID of an existing MLflow run.
 * @param localFile File to upload. Must exist, and must be a simple file (not a directory).
 */
public void logArtifact(String runId, File localFile) {
  ArtifactRepository repository = getArtifactRepository(runId);
  repository.logArtifact(localFile);
}

/**
 * Uploads a single local file beneath an artifactPath inside the run's root artifact
 * directory. For example,
 *
 *   <pre>
 *   logArtifact(runId, new File("/my/localModel"), "model")
 *   listArtifacts(runId, "model") // contains "model/localModel"
 *   </pre>
 *
 * (i.e., the localModel file is now available in model/localModel).
 *
 * @param runId Run ID of an existing MLflow run.
 * @param localFile File to upload. Must exist, and must be a simple file (not a directory).
 * @param artifactPath Artifact path relative to the run's root directory. Should NOT
 *                     start with a /.
 */
public void logArtifact(String runId, File localFile, String artifactPath) {
  ArtifactRepository repository = getArtifactRepository(runId);
  repository.logArtifact(localFile, artifactPath);
}

/**
 * Uploads every file in the given local directory to the run's root artifact directory.
 * For example, if /my/local/dir/ contains two files "file1" and "file2", then
 *
 *   <pre>
 *   logArtifacts(runId, new File("/my/local/dir"))
 *   listArtifacts(runId) // contains "file1" and "file2"
 *   </pre>
 *
 * @param runId Run ID of an existing MLflow run.
 * @param localDir Directory to upload. Must exist, and must be a directory (not a simple file).
 */
public void logArtifacts(String runId, File localDir) {
  ArtifactRepository repository = getArtifactRepository(runId);
  repository.logArtifacts(localDir);
}


/**
 * Uploads every file in the given local directory to an artifactPath within the run's root
 * artifact directory. For example, if /my/local/dir/ contains two files "file1" and "file2", then
 *
 *   <pre>
 *   logArtifacts(runId, new File("/my/local/dir"), "model")
 *   listArtifacts(runId, "model") // contains "model/file1" and "model/file2"
 *   </pre>
 *
 * (i.e., the contents of the local directory are now available in model/).
 *
 * @param runId Run ID of an existing MLflow run.
 * @param localDir Directory to upload. Must exist, and must be a directory (not a simple file).
 * @param artifactPath Artifact path relative to the run's root directory. Should NOT
 *                     start with a /.
 */
public void logArtifacts(String runId, File localDir, String artifactPath) {
  ArtifactRepository repository = getArtifactRepository(runId);
  repository.logArtifacts(localDir, artifactPath);
}

/**
 * Lists the entries directly beneath the run's root artifact directory. The listing is not
 * recursive; directories are reported as FileInfos with isDir=true and can be listed in turn.
 *
 * @param runId Run ID of an existing MLflow run.
 * @return FileInfos for the immediate children of the run's root artifact directory.
 */
public List<FileInfo> listArtifacts(String runId) {
  ArtifactRepository repository = getArtifactRepository(runId);
  return repository.listArtifacts();
}

/**
 * Lists the entries directly beneath the given artifactPath within the run's root artifact
 * directory. The listing is not recursive; directories are reported as FileInfos with
 * isDir=true and can be listed in turn.
 *
 * @param runId Run ID of an existing MLflow run.
 * @param artifactPath Artifact path relative to the run's root directory. Should NOT
 *                     start with a /.
 * @return FileInfos for the immediate children of the given artifactPath.
 */
public List<FileInfo> listArtifacts(String runId, String artifactPath) {
  ArtifactRepository repository = getArtifactRepository(runId);
  return repository.listArtifacts(artifactPath);
}

/**
 * Downloads *all* artifacts within the run's artifact directory.
 * Note that this fetches the entire directory tree, and so may be expensive if
 * the directory has a lot of data.
 *
 * @param runId Run ID of an existing MLflow run.
 * @return A local directory containing the downloaded artifacts.
 */
public File downloadArtifacts(String runId) {
  ArtifactRepository repository = getArtifactRepository(runId);
  return repository.downloadArtifacts();
}

/**
 * Downloads everything stored under the given artifactPath within the run's root artifact
 * directory. For example, if "model/file1" and "model/file2" exist within the artifact
 * directory, then
 *
 *   <pre>
 *   downloadArtifacts(runId, "model") // returns a local directory containing "file1" and "file2"
 *   downloadArtifacts(runId, "model/file1") // returns a local *file* with the contents of file1.
 *   </pre>
 *
 * Note that this fetches the entire subdirectory path, and so may be expensive if
 * the subdirectory has a lot of data.
 *
 * @param runId Run ID of an existing MLflow run.
 * @param artifactPath Artifact path relative to the run's root directory. Should NOT
 *                     start with a /.
 * @return A local file or directory holding the downloaded artifacts.
 */
public File downloadArtifacts(String runId, String artifactPath) {
  ArtifactRepository repository = getArtifactRepository(runId);
  return repository.downloadArtifacts(artifactPath);
}

/**
 * Builds the artifact repository for a run. The run is looked up with {@code getRun(runId)} to
 * obtain its artifact URI, which is then passed to the artifact repository factory.
 *
 * NOTE(review): {@code URI.create} throws IllegalArgumentException when the string is not a
 * valid URI (for example, a path containing spaces) -- confirm the run's artifact URI is
 * always well-formed.
 *
 * @param runId Run ID of an existing MLflow run.
 * @return ArtifactRepository, capable of uploading and downloading MLflow artifacts.
 */
private ArtifactRepository getArtifactRepository(String runId) {
  String artifactUri = getRun(runId).getInfo().getArtifactUri();
  return artifactRepositoryFactory.getArtifactRepository(URI.create(artifactUri), runId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private CliBasedArtifactRepository newRepo() {
logger.info("Created run with id=" + runInfo.getRunUuid() + " and artifactUri=" +
runInfo.getArtifactUri());
return new CliBasedArtifactRepository(runInfo.getArtifactUri(), runInfo.getRunUuid(),
client.getInternalHostCredsProvider());
testClientProvider.getClientHostCredsProvider(client));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
package org.mlflow.tracking;

import java.io.*;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.*;

import org.apache.commons.io.FileUtils;
import org.apache.log4j.Logger;
import org.testng.Assert;
import org.testng.annotations.*;

import static org.mlflow.tracking.TestUtils.*;

import org.mlflow.api.proto.Service.*;
import org.mlflow.artifacts.ArtifactRepository;

public class MlflowClientTest {
private static final Logger logger = Logger.getLogger(MlflowClientTest.class);
Expand Down Expand Up @@ -130,4 +135,18 @@ public void checkParamsAndMetrics() {
assertMetric(metrics, "zero_one_loss", ZERO_ONE_LOSS);
assert(metrics.get(0).getTimestamp() > 0) : metrics.get(0).getTimestamp();
}

/**
 * Round-trips a small text file through the client's artifact API: logs it to the run's root
 * artifact directory, downloads it by name, and checks that the content is unchanged.
 */
@Test
public void testUseArtifactRepository() throws IOException {
  String content = "Hello, Worldz!";

  File tempFile = Files.createTempFile(getClass().getSimpleName(), ".txt").toFile();
  // Avoid leaving the scratch file behind in the system temp directory.
  tempFile.deleteOnExit();
  FileUtils.writeStringToFile(tempFile, content, StandardCharsets.UTF_8);
  client.logArtifact(runId, tempFile);

  File downloadedArtifact = client.downloadArtifacts(runId, tempFile.getName());
  String downloadedContent = FileUtils.readFileToString(downloadedArtifact,
      StandardCharsets.UTF_8);
  // TestNG's assertEquals takes (actual, expected); this order keeps failure messages accurate.
  Assert.assertEquals(downloadedContent, content);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

import org.apache.log4j.Logger;

import org.mlflow.tracking.creds.MlflowHostCreds;
import org.mlflow.tracking.creds.MlflowHostCredsProvider;

/**
* Provides an MLflow API client for testing. This is a real client, pointed to a real server.
* If the MLFLOW_TRACKING_URI environment variable is set, we will talk to the provided server;
Expand Down Expand Up @@ -61,6 +64,10 @@ public void cleanupClientAndServer() throws InterruptedException {
}
}

/**
 * Public bridge to {@code MlflowClient.getInternalHostCredsProvider()}, so test code can obtain
 * the client's credentials provider through this class.
 *
 * @param client Client whose credentials provider is wanted.
 * @return The HostCredsProvider backing the given client.
 */
public MlflowHostCredsProvider getClientHostCredsProvider(MlflowClient client) {
  MlflowHostCredsProvider credsProvider = client.getInternalHostCredsProvider();
  return credsProvider;
}

/**
* Launches an "mlflow server" process locally. This requires that the "mlflow" command
* line client is on the local PATH (e.g., that we're within a conda environment), and that
Expand Down

0 comments on commit 5da2afc

Please sign in to comment.