Skip to content

Commit

Permalink
Add 'mlflow artifacts' CLI to list, download, and upload to artifact …
Browse files Browse the repository at this point in the history
…repos (mlflow#391)
  • Loading branch information
aarondav committed Aug 28, 2018
1 parent cb80239 commit 643ab70
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 47 deletions.
2 changes: 2 additions & 0 deletions mlflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mlflow.utils import cli_args
from mlflow.server import _run_server
from mlflow import tracking
import mlflow.store.cli


@click.group()
Expand Down Expand Up @@ -194,6 +195,7 @@ def server(file_store, default_artifact_root, host, port, workers, static_prefix
cli.add_command(mlflow.sagemaker.cli.commands)
cli.add_command(mlflow.azureml.cli.commands)
cli.add_command(mlflow.experiments.commands)
cli.add_command(mlflow.store.cli.commands)

if __name__ == '__main__':
cli()
43 changes: 19 additions & 24 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import six

from flask import Response, request, send_file
from google.protobuf.json_format import MessageToJson, ParseDict
from querystring_parser import parser

from mlflow.entities import Metric, Param, RunTag
Expand All @@ -16,6 +15,7 @@
DeleteExperiment, RestoreExperiment
from mlflow.store.artifact_repo import ArtifactRepository
from mlflow.store.file_store import FileStore
from mlflow.utils.proto_json_utils import message_to_json, parse_dict


_store = None
Expand All @@ -40,7 +40,7 @@ def _get_request_message(request_message, flask_request=request):
# result.
query_string = re.sub('%5B%5D', '%5B0%5D', flask_request.query_string.decode("utf-8"))
request_dict = parser.parse(query_string, normalized=True)
ParseDict(request_dict, request_message, ignore_unknown_fields=True)
parse_dict(request_dict, request_message)
return request_message

request_json = flask_request.get_json(force=True, silent=True)
Expand All @@ -55,7 +55,7 @@ def _get_request_message(request_message, flask_request=request):
# If request doesn't have json body then assume it's empty.
if request_json is None:
request_json = {}
ParseDict(request_json, request_message, ignore_unknown_fields=True)
parse_dict(request_json, request_message)
return request_message


Expand Down Expand Up @@ -89,19 +89,14 @@ def _not_implemented():
return response


def _message_to_json(message):
# preserving_proto_field_name keeps the JSON-serialized form snake_case
return MessageToJson(message, preserving_proto_field_name=True)


def _create_experiment():
request_message = _get_request_message(CreateExperiment())
experiment_id = _get_store().create_experiment(request_message.name,
request_message.artifact_location)
response_message = CreateExperiment.Response()
response_message.experiment_id = experiment_id
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -113,7 +108,7 @@ def _get_experiment():
run_info_entities = _get_store().list_run_infos(request_message.experiment_id)
response_message.runs.extend([r.to_proto() for r in run_info_entities])
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -122,7 +117,7 @@ def _delete_experiment():
_get_store().delete_experiment(request_message.experiment_id)
response_message = DeleteExperiment.Response()
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -131,7 +126,7 @@ def _restore_experiment():
_get_store().restore_experiment(request_message.experiment_id)
response_message = RestoreExperiment.Response()
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -153,7 +148,7 @@ def _create_run():
response_message = CreateRun.Response()
response_message.run.MergeFrom(run.to_proto())
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -163,7 +158,7 @@ def _update_run():
request_message.end_time)
response_message = UpdateRun.Response(run_info=updated_info.to_proto())
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -173,7 +168,7 @@ def _log_metric():
_get_store().log_metric(request_message.run_uuid, metric)
response_message = LogMetric.Response()
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -183,7 +178,7 @@ def _log_param():
_get_store().log_param(request_message.run_uuid, param)
response_message = LogParam.Response()
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -193,7 +188,7 @@ def _set_tag():
_get_store().set_tag(request_message.run_uuid, tag)
response_message = SetTag.Response()
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -202,7 +197,7 @@ def _get_run():
response_message = GetRun.Response()
response_message.run.MergeFrom(_get_store().get_run(request_message.run_uuid).to_proto())
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -213,7 +208,7 @@ def _search_runs():
request_message.anded_expressions)
response_message.runs.extend([r.to_proto() for r in run_entities])
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -229,7 +224,7 @@ def _list_artifacts():
response_message.files.extend([a.to_proto() for a in artifact_entities])
response_message.root_uri = _get_artifact_repo(run).artifact_uri
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -240,7 +235,7 @@ def _get_metric_history():
request_message.metric_key)
response_message.metrics.extend([m.to_proto() for m in metric_entites])
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -250,7 +245,7 @@ def _get_metric():
metric = _get_store().get_metric(request_message.run_uuid, request_message.metric_key)
response_message.metric.MergeFrom(metric.to_proto())
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -260,7 +255,7 @@ def _get_param():
parameter = _get_store().get_param(request_message.run_uuid, request_message.param_name)
response_message.parameter.MergeFrom(parameter.to_proto())
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand All @@ -270,7 +265,7 @@ def _list_experiments():
response_message = ListExperiments.Response()
response_message.experiments.extend([e.to_proto() for e in experiment_entities])
response = Response(mimetype='application/json')
response.set_data(_message_to_json(response_message))
response.set_data(message_to_json(response_message))
return response


Expand Down
1 change: 1 addition & 0 deletions mlflow/store/artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def download_artifacts(self, artifact_path):
"""
Download an artifact file or directory to a local directory if applicable, and return a
local path for it.
The caller is responsible for managing the lifecycle of the downloaded artifacts.
:param path: Relative source path to the desired artifact
:return: Full path desired artifact.
"""
Expand Down
96 changes: 96 additions & 0 deletions mlflow/store/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from mlflow.utils.logging_utils import eprint

import click

from mlflow.tracking import _get_store
from mlflow.store.artifact_repo import ArtifactRepository
from mlflow.utils.proto_json_utils import message_to_json


@click.group("artifacts")
def commands():
"""Upload, list, and download artifacts from an MLflow artifact repository."""
pass


@commands.command("log-artifact")
@click.option("--local-file", "-l", required=True,
help="Local path to artifact to log")
@click.option("--run-id", "-r", required=True,
help="Run ID into which we should log the artifact.")
@click.option("--artifact-path", "-a",
help="If specified, we will log the artifact into this subdirectory of the " +
"run's artifact directory.")
def log_artifact(local_file, run_id, artifact_path):
"""
Logs a local file as an artifact of a run, optionally within a run-specific
artifact path. Run artifacts can be organized into directories, so you can
place the artifact in a directory this way.
"""
store = _get_store()
artifact_uri = store.get_run(run_id).info.artifact_uri
artifact_repo = ArtifactRepository.from_artifact_uri(artifact_uri, store)
artifact_repo.log_artifact(local_file, artifact_path)
eprint("Logged artifact from local file %s to artifact_path=%s" % (local_file, artifact_path))


@commands.command("log-artifacts")
@click.option("--local-dir", "-l", required=True,
help="Directory of local artifacts to log")
@click.option("--run-id", "-r", required=True,
help="Run ID into which we should log the artifact.")
@click.option("--artifact-path", "-a",
help="If specified, we will log the artifact into this subdirectory of the " +
"run's artifact directory.")
def log_artifacts(local_dir, run_id, artifact_path):
"""
Logs the files within a local directory as an artifact of a run, optionally
within a run-specific artifact path. Run artifacts can be organized into
directories, so you can place the artifact in a directory this way.
"""
store = _get_store()
artifact_uri = store.get_run(run_id).info.artifact_uri
artifact_repo = ArtifactRepository.from_artifact_uri(artifact_uri, store)
artifact_repo.log_artifacts(local_dir, artifact_path)
eprint("Logged artifact from local dir %s to artifact_path=%s" % (local_dir, artifact_path))


@commands.command("list")
@click.option("--run-id", "-r", required=True,
help="Run ID to be listed")
@click.option("--artifact-path", "-a",
help="If specified, a path relative to the run's root directory to list.")
def list_artifacts(run_id, artifact_path):
"""
Return all the artifacts directly under run's root artifact directory,
or a sub-directory. The output is a JSON-formatted list.
"""
artifact_path = artifact_path if artifact_path is not None else ""
store = _get_store()
artifact_uri = store.get_run(run_id).info.artifact_uri
artifact_repo = ArtifactRepository.from_artifact_uri(artifact_uri, store)
file_infos = artifact_repo.list_artifacts(artifact_path)
print(_file_infos_to_json(file_infos))


def _file_infos_to_json(file_infos):
json_list = [message_to_json(file_info.to_proto()) for file_info in file_infos]
return "[" + ", ".join(json_list) + "]"


@commands.command("download")
@click.option("--run-id", "-r", required=True,
help="Run ID from which to download")
@click.option("--artifact-path", "-a",
help="If specified, a path relative to the run's root directory to download")
def download_artifacts(run_id, artifact_path):
"""
Download an artifact file or directory to a local directory.
The output is the name of the file or directory on the local disk.
"""
artifact_path = artifact_path if artifact_path is not None else ""
store = _get_store()
artifact_uri = store.get_run(run_id).info.artifact_uri
artifact_repo = ArtifactRepository.from_artifact_uri(artifact_uri, store)
artifact_location = artifact_repo.download_artifacts(artifact_path)
print(artifact_location)
Loading

0 comments on commit 643ab70

Please sign in to comment.