Skip to content

Commit

Permalink
Update java searchRuns to return list of Runs instead of RunInfos (ml…
Browse files Browse the repository at this point in the history
…flow#1518)

* Change Java searchRuns to return list of Runs, not RunInfos
  • Loading branch information
max-allen-db committed Jun 28, 2019
1 parent 4d9f7f7 commit bc83869
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ public RunInfo createRun(CreateRun request) {
public List<RunInfo> listRunInfos(String experimentId) {
List<String> experimentIds = new ArrayList<>();
experimentIds.add(experimentId);
return searchRuns(experimentIds, null);
return searchRuns(experimentIds, null).stream().map(Run::getInfo)
.collect(Collectors.toList());
}

/**
Expand All @@ -130,9 +131,9 @@ public List<RunInfo> listRunInfos(String experimentId) {
* similar to that specified on MLflow UI.
* Example : "params.model = 'LogisticRegression' and metrics.acc = 0.9"
*
* @return A list of all RunInfos that satisfy search filter.
* @return A list of all Runs that satisfy search filter.
*/
public List<RunInfo> searchRuns(List<String> experimentIds, String searchFilter) {
public List<Run> searchRuns(List<String> experimentIds, String searchFilter) {
return searchRuns(experimentIds, searchFilter, ViewType.ACTIVE_ONLY);
}

Expand All @@ -146,9 +147,9 @@ public List<RunInfo> searchRuns(List<String> experimentIds, String searchFilter)
* @param runViewType ViewType for expected runs. One of (ACTIVE_ONLY, DELETED_ONLY, ALL)
* Defaults to ACTIVE_ONLY.
*
* @return A list of all RunInfos that satisfy search filter.
* @return A list of all Runs that satisfy search filter.
*/
public List<RunInfo> searchRuns(List<String> experimentIds,
public List<Run> searchRuns(List<String> experimentIds,
String searchFilter,
ViewType runViewType) {
return searchRuns(experimentIds, searchFilter, runViewType, new ArrayList<>());
Expand All @@ -168,7 +169,7 @@ public List<RunInfo> searchRuns(List<String> experimentIds,
*
* @return A list of all RunInfos that satisfy search filter.
*/
public List<RunInfo> searchRuns(List<String> experimentIds,
public List<Run> searchRuns(List<String> experimentIds,
String searchFilter,
ViewType runViewType,
List<String> orderBy) {
Expand All @@ -185,8 +186,7 @@ public List<RunInfo> searchRuns(List<String> experimentIds,
SearchRuns request = builder.build();
String ijson = mapper.toJson(request);
String ojson = sendPost("runs/search", ijson);
return mapper.toSearchRunsResponse(ojson).getRunsList().stream().map(Run::getInfo)
.collect(Collectors.toList());
return mapper.toSearchRunsResponse(ojson).getRunsList();
}

/** @return A list of all experiments. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ public void searchRuns() {
List<String> experimentIds = Arrays.asList(expId);

// metrics based searches
List<RunInfo> searchResult = client.searchRuns(experimentIds, "metrics.accuracy_score < 0");
List<Run> searchResult = client.searchRuns(experimentIds, "metrics.accuracy_score < 0");
Assert.assertEquals(searchResult.size(), 0);

searchResult = client.searchRuns(experimentIds, "metrics.accuracy_score > 0");
Expand All @@ -221,10 +221,16 @@ public void searchRuns() {
Assert.assertEquals(searchResult.size(), 0);

searchResult = client.searchRuns(experimentIds, "metrics.accuracy_score < 0.5");
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(0).getInfo().getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(0).getData().getMetricsList().size(), 1);
Assert.assertEquals(searchResult.get(0).getData().getParamsList().size(), 2);
Assert.assertEquals(searchResult.get(0).getData().getTagsList().size(), 2);

searchResult = client.searchRuns(experimentIds, "metrics.accuracy_score > 0.5");
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_2);
Assert.assertEquals(searchResult.get(0).getInfo().getRunUuid(), runId_2);
Assert.assertEquals(searchResult.get(0).getData().getMetricsList().size(), 1);
Assert.assertEquals(searchResult.get(0).getData().getParamsList().size(), 2);
Assert.assertEquals(searchResult.get(0).getData().getTagsList().size(), 1);

// parameter based searches
searchResult = client.searchRuns(experimentIds,
Expand All @@ -234,33 +240,33 @@ public void searchRuns() {
"params.min_samples_leaf != '" + MIN_SAMPLES_LEAF + "'");
Assert.assertEquals(searchResult.size(), 0);
searchResult = client.searchRuns(experimentIds, "params.max_depth = '5'");
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(0).getInfo().getRunUuid(), runId_1);

searchResult = client.searchRuns(experimentIds, "params.max_depth = '15'");
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_2);
Assert.assertEquals(searchResult.get(0).getInfo().getRunUuid(), runId_2);

// tag based search
searchResult = client.searchRuns(experimentIds, "tag.user_email = '" + USER_EMAIL + "'");
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(0).getInfo().getRunUuid(), runId_1);

searchResult = client.searchRuns(experimentIds, "tag.user_email != '" + USER_EMAIL + "'");
Assert.assertEquals(searchResult.size(), 0);

searchResult = client.searchRuns(experimentIds, "tag.test = 'works'");
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(0).getInfo().getRunUuid(), runId_1);

searchResult = client.searchRuns(experimentIds, "tag.test = 'also works'");
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_2);
Assert.assertEquals(searchResult.get(0).getInfo().getRunUuid(), runId_2);

searchResult = client.searchRuns(experimentIds, "", ViewType.ACTIVE_ONLY,
Lists.newArrayList("metrics.accuracy_score"));
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(1).getRunUuid(), runId_2);
Assert.assertEquals(searchResult.get(0).getInfo().getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(1).getInfo().getRunUuid(), runId_2);

searchResult = client.searchRuns(experimentIds, "", ViewType.ACTIVE_ONLY,
Lists.newArrayList("params.min_samples_leaf", "metrics.accuracy_score DESC"));
Assert.assertEquals(searchResult.get(1).getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_2);
Assert.assertEquals(searchResult.get(1).getInfo().getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(0).getInfo().getRunUuid(), runId_2);
}

@Test
Expand Down

0 comments on commit bc83869

Please sign in to comment.