Add LogBatch implementations (client, store) for local use (mlflow#950)
Thanks again @aarondav, merging for now - happy to make a quick follow-up if there's anything else I should address :)
smurching committed Mar 7, 2019
1 parent dff9502 commit 70d518b
Showing 10 changed files with 3,066 additions and 81 deletions.
2,642 changes: 2,576 additions & 66 deletions mlflow/java/client/src/main/java/org/mlflow/api/proto/Service.java

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions mlflow/protos/service.proto
@@ -276,6 +276,21 @@ service MlflowService {
      rpc_doc_title: "Get Metric History",
    };
  }

  // Log a batch of metrics, params, and/or tags for a run. The server will respond with an error
  // if any data failed to be persisted.
  //
  rpc logBatch (LogBatch) returns (LogBatch.Response) {
    option (rpc) = {
      endpoints: [{
        method: "POST",
        path: "/preview/mlflow/runs/log-batch"
        since { major: 2, minor: 0 },
      }],
      visibility: PUBLIC,
      rpc_doc_title: "Log Batch",
    };
  }
}

// View type for ListExperiments query.
@@ -812,3 +827,17 @@ message GetMetricHistory {
    repeated Metric metrics = 1;
  }
}

message LogBatch {
  option (scalapb.message).extends = "com.databricks.rpc.RPC[$this.Response]";
  // ID of the run to log under
  optional string run_id = 1;
  // Metrics to log
  repeated Metric metrics = 2;
  // Params to log
  repeated Param params = 3;
  // Tags to log
  repeated RunTag tags = 4;
  message Response {
  }
}
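For orientation (not part of the diff): once a server implements this RPC, a client could call the new endpoint roughly as in the sketch below. The host, port, and /api/2.0 prefix are assumptions; only the /preview/mlflow/runs/log-batch path and the run_id/metrics/params/tags field names come from the proto above.

    # Hypothetical client-side sketch of calling the Log Batch endpoint.
    # Host, port, the /api/2.0 prefix, and all payload values are assumptions.
    import requests

    payload = {
        "run_id": "0123456789abcdef0123456789abcdef",  # placeholder run ID
        "metrics": [{"key": "m1", "value": 0.87, "timestamp": 12345}],
        "params": [{"key": "p1", "value": "p1val"}],
        "tags": [{"key": "t1", "value": "t1val"}],
    }
    resp = requests.post(
        "http://localhost:5000/api/2.0/preview/mlflow/runs/log-batch",
        json=payload,
    )
    resp.raise_for_status()  # the server responds with an error if any data failed to persist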
123 changes: 114 additions & 9 deletions mlflow/protos/service_pb2.py

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions mlflow/store/abstract_store.py
@@ -210,3 +210,17 @@ def list_run_infos(self, experiment_id, run_view_type):
        """
        runs = self.search_runs([experiment_id], None, run_view_type)
        return [run.info for run in runs]

    @abstractmethod
    def log_batch(self, run_id, metrics, params, tags):
        """
        Logs multiple metrics, params, and tags for the specified run.
        :param run_id: String ID of the run
        :param metrics: List of :py:class:`mlflow.entities.Metric` instances to log
        :param params: List of :py:class:`mlflow.entities.Param` instances to log
        :param tags: List of :py:class:`mlflow.entities.RunTag` instances to log
        :returns: Tuple (failed_metrics, failed_params, failed_tags) where each element of
                  the tuple is a list of :py:class:`mlflow.protos.service_pb2.BatchLogFailure`
                  protos describing metrics/params/tags that failed to be logged & why.
        """
        pass
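As a rough usage sketch (not part of the change), any concrete store exposing this method can batch-log entities in one call; the store location and run ID below are placeholders:

    # Hypothetical usage of the new store-level log_batch API. The FileStore path
    # and run ID are placeholders; the run is assumed to already exist and be active.
    from mlflow.entities import Metric, Param, RunTag
    from mlflow.store.file_store import FileStore

    store = FileStore("/tmp/mlruns")                 # assumed local store root
    run_id = "0123456789abcdef0123456789abcdef"      # placeholder for an existing run
    store.log_batch(
        run_id=run_id,
        metrics=[Metric("rmse", 0.25, 1551900000)],
        params=[Param("alpha", "0.5")],
        tags=[RunTag("team", "forecasting")],
    )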
2 changes: 1 addition & 1 deletion mlflow/store/dbmodels/models.py
@@ -140,7 +140,7 @@ class SqlMetric(Base):

    key = Column(String(250))
    value = Column(Float, nullable=False)
-    timestamp = Column(BigInteger, default=int(time.time()))
+    timestamp = Column(BigInteger, default=lambda: int(time.time()))
    run_uuid = Column(String(32), ForeignKey('runs.run_uuid'))
    run = relationship('SqlRun', backref=backref('metrics', cascade='all'))
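The one-character change above is the fix: a plain default=int(time.time()) is evaluated once, when the class is defined, so every SqlMetric row would share the import-time timestamp, while a callable default is invoked on each insert. A standalone sketch of the distinction (generic SQLAlchemy, not MLflow code):

    # Standalone sketch (not MLflow code) contrasting scalar vs. callable column defaults.
    import time
    from sqlalchemy import BigInteger, Column, Integer
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Event(Base):
        __tablename__ = "events"
        id = Column(Integer, primary_key=True)
        # Evaluated once, at class-definition (import) time: every row gets the same value.
        frozen_ts = Column(BigInteger, default=int(time.time()))
        # Evaluated per INSERT: each row gets a current timestamp.
        fresh_ts = Column(BigInteger, default=lambda: int(time.time()))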

15 changes: 15 additions & 0 deletions mlflow/store/file_store.py
@@ -10,6 +10,7 @@
from mlflow.entities.run_info import check_run_is_active, check_run_is_deleted
from mlflow.exceptions import MlflowException, MissingConfigException
import mlflow.protos.databricks_pb2 as databricks_pb2
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
from mlflow.store import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
from mlflow.store.abstract_store import AbstractStore
from mlflow.utils.validation import _validate_metric_name, _validate_param_name, _validate_run_id, \
@@ -533,3 +534,17 @@ def _overwrite_run_info(self, run_info):
        run_dir = self._get_run_dir(run_info.experiment_id, run_info.run_uuid)
        run_info_dict = _make_persisted_run_info_dict(run_info)
        write_yaml(run_dir, FileStore.META_DATA_FILE_NAME, run_info_dict, overwrite=True)

    def log_batch(self, run_id, metrics, params, tags):
        _validate_run_id(run_id)
        run = self.get_run(run_id)
        check_run_is_active(run.info)
        try:
            for param in params:
                self.log_param(run_id, param)
            for metric in metrics:
                self.log_metric(run_id, metric)
            for tag in tags:
                self.set_tag(run_id, tag)
        except Exception as e:
            raise MlflowException(e, INTERNAL_ERROR)
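Per the implementation above, failures inside the logging loops are re-raised as MlflowException with INTERNAL_ERROR, while a missing run surfaces earlier from get_run. A hedged sketch of how a caller might observe this (the store root is a placeholder):

    # Hypothetical: observing the error codes raised by FileStore.log_batch.
    # The store root is a placeholder; error_code/message are used as in the tests below.
    import uuid
    from mlflow.exceptions import MlflowException
    from mlflow.store.file_store import FileStore

    store = FileStore("/tmp/mlruns")          # assumed local store root
    missing_run_id = uuid.uuid4().hex         # well-formed but nonexistent run ID
    try:
        store.log_batch(missing_run_id, metrics=[], params=[], tags=[])
    except MlflowException as e:
        print(e.error_code, e.message)        # e.g. RESOURCE_DOES_NOT_EXIST, "Run '...' not found"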
4 changes: 4 additions & 0 deletions mlflow/store/rest_store.py
@@ -3,6 +3,7 @@
from mlflow.store.abstract_store import AbstractStore

from mlflow.entities import Experiment, Run, RunInfo, RunTag, Metric, ViewType
from mlflow.exceptions import MlflowException

from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
@@ -233,3 +234,6 @@ def delete_run(self, run_id):
    def restore_run(self, run_id):
        req_body = message_to_json(RestoreRun(run_id=run_id))
        self._call_endpoint(RestoreRun, req_body)

    def log_batch(self, run_id, metrics, params, tags):
        raise MlflowException("The LogBatch REST API is not yet implemented")
27 changes: 22 additions & 5 deletions mlflow/store/sqlalchemy_store.py
@@ -15,6 +15,7 @@
from mlflow.tracking.utils import _is_local_uri
from mlflow.utils.file_utils import build_path, mkdir
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID, MLFLOW_RUN_NAME
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR


class SqlAlchemyStore(AbstractStore):
@@ -86,7 +87,6 @@ def _save_to_db(self, objs):
"""
Store in db
"""

if type(objs) is list:
self.session.add_all(objs)
else:
@@ -240,12 +240,14 @@ def _get_run(self, run_uuid):
    def _check_run_is_active(self, run):
        if run.lifecycle_stage != LifecycleStage.ACTIVE:
            raise MlflowException("The run {} must be in 'active' state. Current state is {}."
-                                 .format(run.run_uuid, run.lifecycle_stage))
+                                 .format(run.run_uuid, run.lifecycle_stage),
+                                 INVALID_PARAMETER_VALUE)

    def _check_run_is_deleted(self, run):
        if run.lifecycle_stage != LifecycleStage.DELETED:
            raise MlflowException("The run {} must be in 'deleted' state. Current state is {}."
-                                 .format(run.run_uuid, run.lifecycle_stage))
+                                 .format(run.run_uuid, run.lifecycle_stage),
+                                 INVALID_PARAMETER_VALUE)

    def update_run_info(self, run_uuid, run_status, end_time):
        run = self._get_run(run_uuid)
@@ -334,8 +336,8 @@ def log_param(self, run_uuid, param):
    def set_tag(self, run_uuid, tag):
        run = self._get_run(run_uuid)
        self._check_run_is_active(run)
-        new_tag = SqlTag(run_uuid=run_uuid, key=tag.key, value=tag.value)
-        self._save_to_db(new_tag)
+        self.session.merge(SqlTag(run_uuid=run_uuid, key=tag.key, value=tag.value))
+        self.session.commit()

    def search_runs(self, experiment_ids, search_filter, run_view_type):
        runs = [run.to_mlflow_entity()
@@ -347,3 +349,18 @@ def _list_runs(self, experiment_id, run_view_type):
        exp = self._list_experiments(ids=[experiment_id], view_type=ViewType.ALL).first()
        stages = set(LifecycleStage.view_type_to_stages(run_view_type))
        return [run for run in exp.runs if run.lifecycle_stage in stages]

    def log_batch(self, run_id, metrics, params, tags):
        run = self._get_run(run_id)
        self._check_run_is_active(run)
        try:
            for param in params:
                self.log_param(run_id, param)
            for metric in metrics:
                self.log_metric(run_id, metric)
            for tag in tags:
                self.set_tag(run_id, tag)
        except MlflowException as e:
            raise e
        except Exception as e:
            raise MlflowException(e, INTERNAL_ERROR)
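For context on the set_tag change above: session.merge upserts by primary key (assumed here to be (run_uuid, key) for tags, as in dbmodels), so re-setting an existing tag updates its value instead of inserting a duplicate row, which is what the tag-overwrite tests below rely on. A standalone sketch of the pattern (generic SQLAlchemy, not MLflow code):

    # Standalone sketch (not MLflow code): merge() as an upsert keyed on the primary key.
    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class Tag(Base):
        __tablename__ = "tags"
        run_uuid = Column(String(32), primary_key=True)
        key = Column(String(250), primary_key=True)
        value = Column(String(250))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    session.merge(Tag(run_uuid="abc", key="t-key", value="val"))
    session.merge(Tag(run_uuid="abc", key="t-key", value="newval"))  # updates, no duplicate row
    session.commit()
    print(session.query(Tag).count())  # 1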
129 changes: 129 additions & 0 deletions tests/store/test_file_store.py
@@ -13,6 +13,7 @@
from mlflow.exceptions import MlflowException, MissingConfigException
from mlflow.store.file_store import FileStore
from mlflow.utils.file_utils import write_yaml, read_yaml
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST, INTERNAL_ERROR
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID
from tests.helper_functions import random_int, random_str

@@ -520,3 +521,131 @@ def test_bad_experiment_id_recorded_for_run(self):
        for rid in all_run_ids:
            if rid != bad_run_id:
                fs.get_run(rid)

    def test_log_batch(self):
        fs = FileStore(self.test_root)
        run = fs.create_run(
            experiment_id=Experiment.DEFAULT_EXPERIMENT_ID, user_id='user', run_name=None,
            source_type='source_type', source_name='source_name',
            entry_point_name='entry_point_name', start_time=0, source_version=None, tags=[],
            parent_run_id=None)
        run_uuid = run.info.run_uuid
        metric_entities = [Metric("m1", 0.87, 12345), Metric("m2", 0.49, 12345)]
        param_entities = [Param("p1", "p1val"), Param("p2", "p2val")]
        tag_entities = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
        fs.log_batch(
            run_id=run_uuid, metrics=metric_entities, params=param_entities, tags=tag_entities)
        run = fs.get_run(run_uuid)
        tags = [(t.key, t.value) for t in run.data.tags]
        metrics = [(m.key, m.value, m.timestamp) for m in run.data.metrics]
        params = [(p.key, p.value) for p in run.data.params]
        assert set(tags) == set([("t1", "t1val"), ("t2", "t2val")])
        assert set(metrics) == set([("m1", 0.87, 12345), ("m2", 0.49, 12345)])
        assert set(params) == set([("p1", "p1val"), ("p2", "p2val")])

    def _create_run(self, fs):
        return fs.create_run(
            experiment_id=Experiment.DEFAULT_EXPERIMENT_ID, user_id='user', run_name=None,
            source_type='source_type', source_name='source_name',
            entry_point_name='entry_point_name', start_time=0, source_version=None, tags=[],
            parent_run_id=None)

    def _verify_logged(self, fs, run_uuid, metrics, params, tags):
        run = fs.get_run(run_uuid)
        all_metrics = sum([fs.get_metric_history(run_uuid, m.key)
                           for m in run.data.metrics], [])
        assert len(all_metrics) == len(metrics)
        logged_metrics = [(m.key, m.value, m.timestamp) for m in all_metrics]
        assert set(logged_metrics) == set([(m.key, m.value, m.timestamp) for m in metrics])
        assert len(run.data.tags) == len(tags)
        logged_tags = [(tag.key, tag.value) for tag in run.data.tags]
        assert set(logged_tags) == set([(tag.key, tag.value) for tag in tags])
        assert len(run.data.params) == len(params)
        logged_params = [(param.key, param.value) for param in run.data.params]
        assert set(logged_params) == set([(param.key, param.value) for param in params])

    def test_log_batch_internal_error(self):
        # Verify that internal errors during log_batch result in MlflowExceptions
        fs = FileStore(self.test_root)
        run = self._create_run(fs)

        def _raise_exception_fn(*args, **kwargs):  # pylint: disable=unused-argument
            raise Exception("Some internal error")
        with mock.patch("mlflow.store.file_store.FileStore.log_metric") as log_metric_mock, \
                mock.patch("mlflow.store.file_store.FileStore.log_param") as log_param_mock, \
                mock.patch("mlflow.store.file_store.FileStore.set_tag") as set_tag_mock:
            log_metric_mock.side_effect = _raise_exception_fn
            log_param_mock.side_effect = _raise_exception_fn
            set_tag_mock.side_effect = _raise_exception_fn
            for kwargs in [{"metrics": [Metric("a", 3, 1)]}, {"params": [Param("b", "c")]},
                           {"tags": [RunTag("c", "d")]}]:
                log_batch_kwargs = {"metrics": [], "params": [], "tags": []}
                log_batch_kwargs.update(kwargs)
                print(log_batch_kwargs)
                with self.assertRaises(MlflowException) as e:
                    fs.log_batch(run.info.run_uuid, **log_batch_kwargs)
                self.assertIn("Some internal error", e.exception.message)
                assert e.exception.error_code == ErrorCode.Name(INTERNAL_ERROR)

    def test_log_batch_nonexistent_run(self):
        fs = FileStore(self.test_root)
        nonexistent_uuid = uuid.uuid4().hex
        with self.assertRaises(MlflowException) as e:
            fs.log_batch(nonexistent_uuid, [], [], [])
        assert e.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
        assert ("Run '%s' not found" % nonexistent_uuid) in e.exception.message

    def test_log_batch_params_idempotency(self):
        fs = FileStore(self.test_root)
        run = self._create_run(fs)
        params = [Param("p-key", "p-val")]
        fs.log_batch(run.info.run_uuid, metrics=[], params=params, tags=[])
        fs.log_batch(run.info.run_uuid, metrics=[], params=params, tags=[])
        self._verify_logged(fs, run.info.run_uuid, metrics=[], params=params, tags=[])

    def test_log_batch_tags_idempotency(self):
        fs = FileStore(self.test_root)
        run = self._create_run(fs)
        fs.log_batch(run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "t-val")])
        fs.log_batch(run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "t-val")])
        self._verify_logged(fs, run.info.run_uuid, metrics=[], params=[],
                            tags=[RunTag("t-key", "t-val")])

    def test_log_batch_allows_tag_overwrite(self):
        fs = FileStore(self.test_root)
        run = self._create_run(fs)
        fs.log_batch(run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "val")])
        fs.log_batch(run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "newval")])
        self._verify_logged(fs, run.info.run_uuid, metrics=[], params=[],
                            tags=[RunTag("t-key", "newval")])

    def test_log_batch_same_metric_repeated_single_req(self):
        fs = FileStore(self.test_root)
        run = self._create_run(fs)
        metric0 = Metric(key="metric-key", value=1, timestamp=2)
        metric1 = Metric(key="metric-key", value=2, timestamp=3)
        fs.log_batch(run.info.run_uuid, params=[], metrics=[metric0, metric1], tags=[])
        self._verify_logged(fs, run.info.run_uuid, params=[], metrics=[metric0, metric1], tags=[])

    def test_log_batch_same_metric_repeated_multiple_reqs(self):
        fs = FileStore(self.test_root)
        run = self._create_run(fs)
        metric0 = Metric(key="metric-key", value=1, timestamp=2)
        metric1 = Metric(key="metric-key", value=2, timestamp=3)
        fs.log_batch(run.info.run_uuid, params=[], metrics=[metric0], tags=[])
        self._verify_logged(fs, run.info.run_uuid, params=[], metrics=[metric0], tags=[])
        fs.log_batch(run.info.run_uuid, params=[], metrics=[metric1], tags=[])
        self._verify_logged(fs, run.info.run_uuid, params=[], metrics=[metric0, metric1], tags=[])

    def test_log_batch_allows_tag_overwrite_single_req(self):
        fs = FileStore(self.test_root)
        run = self._create_run(fs)
        tags = [RunTag("t-key", "val"), RunTag("t-key", "newval")]
        fs.log_batch(run.info.run_uuid, metrics=[], params=[], tags=tags)
        self._verify_logged(fs, run.info.run_uuid, metrics=[], params=[], tags=[tags[-1]])

    def test_log_batch_accepts_empty_payload(self):
        fs = FileStore(self.test_root)
        run = self._create_run(fs)
        fs.log_batch(run.info.run_uuid, metrics=[], params=[], tags=[])
        self._verify_logged(fs, run.info.run_uuid, metrics=[], params=[], tags=[])