Add Python fluent & client batched-logging APIs & unit tests (mlflow#955)

* Add local batch-logging functionality

* Copy over proto changes

* Fix some linter errors & add / fix some more DB tests

* Remove fluent API & add FileStore tests

* Remove change to duplicate-timestamp behavior in SQL log_metric (to be discussed offline first)

* Remove unnecessary line deletion

* Remove validation

* Remove use of validation helpers

* Remove unneeded exception-error-code-change

* Remove validations for now

* Fix test + improve abstract store signature of log_batch (use run_id instead of run_uuid)

* Use run_id instead of run_uuid in client API

* Add new store tests, remove TODO

* Undo all changes specific to this PR

* Revert "Undo all changes specific to this PR"

This reverts commit 7a3e629.

* Add more tests for fluent API + logic for entity equality & tests

* Test inequality case too

* Undo entity equality changes

* Fix tests w/o entity equality changes

* Fix lint

* Address review comments

* Address review comments / fix tests after merging master
smurching authored and dbczumar committed Mar 12, 2019
1 parent 1581fd4 commit af714c2
Showing 6 changed files with 221 additions and 22 deletions.
3 changes: 3 additions & 0 deletions mlflow/__init__.py
@@ -58,6 +58,9 @@
get_tracking_uri = tracking.get_tracking_uri
create_experiment = mlflow.tracking.fluent.create_experiment
set_experiment = mlflow.tracking.fluent.set_experiment
log_params = mlflow.tracking.fluent.log_params
log_metrics = mlflow.tracking.fluent.log_metrics
set_tags = mlflow.tracking.fluent.set_tags


run = projects.run
25 changes: 22 additions & 3 deletions mlflow/tracking/client.py
@@ -9,8 +9,8 @@
from six import iteritems

from mlflow.tracking import utils
from mlflow.utils.validation import _validate_metric_name, _validate_param_name, \
_validate_tag_name, _validate_run_id, _validate_experiment_name
from mlflow.utils.validation import _validate_param_name, _validate_tag_name, _validate_run_id, \
_validate_experiment_name, _validate_metric
from mlflow.entities import Param, Metric, RunStatus, RunTag, ViewType, SourceType
from mlflow.store.artifact_repository_registry import get_artifact_repository

@@ -131,8 +131,8 @@ def log_metric(self, run_id, key, value, timestamp=None):
Log a metric against the run ID. If timestamp is not provided, uses
the current timestamp.
"""
_validate_metric_name(key)
timestamp = timestamp if timestamp is not None else int(time.time())
_validate_metric(key, value, timestamp)
metric = Metric(key, value, timestamp)
self.store.log_metric(run_id, metric)

@@ -152,6 +152,25 @@ def set_tag(self, run_id, key, value):
tag = RunTag(key, str(value))
self.store.set_tag(run_id, tag)

def log_batch(self, run_id, metrics, params, tags):
"""
Log multiple metrics, params, and/or tags against the given run ID.
:param metrics: List of Metric(key, value, timestamp) instances.
:param params: List of Param(key, value) instances.
:param tags: List of RunTag(key, value) instances.
Raises an MlflowException if any errors occur.
:returns: None
"""
for metric in metrics:
_validate_metric(metric.key, metric.value, metric.timestamp)
for param in params:
_validate_param_name(param.key)
for tag in tags:
_validate_tag_name(tag.key)
self.store.log_batch(run_id=run_id, metrics=metrics, params=params, tags=tags)

def log_artifact(self, run_id, local_path, artifact_path=None):
"""
Write a local file to the remote ``artifact_uri``.
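For reference, a minimal usage sketch of the new MlflowClient.log_batch method (the run ID and entity values below are illustrative placeholders, not part of this diff):

import time

from mlflow.entities import Metric, Param, RunTag
from mlflow.tracking import MlflowClient

client = MlflowClient()
# Placeholder run ID; use the ID of a run created via client.create_run(...) or mlflow.start_run().
run_id = "0123456789abcdef0123456789abcdef"
timestamp = int(time.time())

# Log two metrics, one param, and one tag in a single store call.
client.log_batch(
    run_id=run_id,
    metrics=[Metric("rmse", 0.25, timestamp), Metric("mae", 0.19, timestamp)],
    params=[Param("alpha", "0.5")],
    tags=[RunTag("team", "data-science")],
)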
43 changes: 37 additions & 6 deletions mlflow/tracking/fluent.py
@@ -5,15 +5,14 @@

from __future__ import print_function

import numbers
import os

import atexit
import time
import logging

import mlflow.tracking.utils
from mlflow.entities import Experiment, Run, RunStatus, SourceType
from mlflow.entities import Experiment, Run, SourceType, RunStatus, Param, RunTag, Metric
from mlflow.entities.lifecycle_stage import LifecycleStage
from mlflow.exceptions import MlflowException
from mlflow.tracking.client import MlflowClient
@@ -201,14 +200,46 @@ def log_metric(key, value):
:param key: Metric name (string).
:param value: Metric value (float).
"""
if not isinstance(value, numbers.Number):
_logger.warning(
"The metric %s=%s was not logged because the value is not a number.", key, value)
return
run_id = _get_or_start_run().info.run_uuid
MlflowClient().log_metric(run_id, key, value, int(time.time()))


def log_metrics(metrics):
"""
Log multiple metrics for the current run, starting a run if no runs are active.
:param metrics: Dictionary mapping metric name (string) to metric value (float).
:returns: None
"""
run_id = _get_or_start_run().info.run_uuid
timestamp = int(time.time())
metrics_arr = [Metric(key, value, timestamp) for key, value in metrics.items()]
MlflowClient().log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[])


def log_params(params):
"""
Log a batch of params for the current run, starting a run if no runs are active.
:param params: Dictionary mapping param name (string) to param value (string; non-string values will be stringified).
:returns: None
"""
run_id = _get_or_start_run().info.run_uuid
params_arr = [Param(key, value) for key, value in params.items()]
MlflowClient().log_batch(run_id=run_id, metrics=[], params=params_arr, tags=[])


def set_tags(tags):
"""
Log a batch of tags for the current run, starting a run if no runs are active.
:param tags: Dictionary mapping tag name (string) to tag value (string; non-string values will be stringified).
:returns: None
"""
run_id = _get_or_start_run().info.run_uuid
tags_arr = [RunTag(key, value) for key, value in tags.items()]
MlflowClient().log_batch(run_id=run_id, metrics=[], params=[], tags=tags_arr)


def log_artifact(local_path, artifact_path=None):
"""
Log a local file or directory as an artifact of the currently active run.
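A minimal sketch of the new fluent batched-logging calls added above (values are illustrative). Each call starts a run if none is active, so the explicit start_run below is optional:

import mlflow

with mlflow.start_run():
    # Each dictionary is converted to a list of entities and sent via a single log_batch call.
    mlflow.log_params({"alpha": "0.5", "l1_ratio": "0.1"})
    mlflow.log_metrics({"rmse": 0.25, "mae": 0.19})
    mlflow.set_tags({"team": "data-science", "stage": "dev"})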
38 changes: 32 additions & 6 deletions mlflow/utils/validation.py
@@ -1,9 +1,12 @@
"""
Utilities for validating user inputs such as metric names and parameter names.
"""
import numbers
import os.path
import re

import numpy as np

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE

@@ -33,26 +36,49 @@ def path_not_unique(name):
def _validate_metric_name(name):
"""Check that `name` is a valid metric name and raise an exception if it isn't."""
if not _VALID_PARAM_AND_METRIC_NAMES.match(name):
raise Exception("Invalid metric name: '%s'. %s" % (name, _BAD_CHARACTERS_MESSAGE))
raise MlflowException("Invalid metric name: '%s'. %s" % (name, _BAD_CHARACTERS_MESSAGE),
INVALID_PARAMETER_VALUE)
if path_not_unique(name):
raise Exception("Invalid metric name: '%s'. %s" % (name, bad_path_message(name)))
raise MlflowException("Invalid metric name: '%s'. %s" % (name, bad_path_message(name)),
INVALID_PARAMETER_VALUE)


def _validate_metric(key, value, timestamp):
_validate_metric_name(key)
if not isinstance(value, numbers.Number) or value > np.finfo(np.float64).max \
or value < np.finfo(np.float64).min:
raise MlflowException(
"Got invalid value %s for metric '%s' (timestamp=%s). Please specify value as a valid "
"double (64-bit floating point)" % (value, key, timestamp),
INVALID_PARAMETER_VALUE)

if not isinstance(timestamp, numbers.Number) or timestamp < 0 or \
timestamp > np.iinfo(np.int64).max:
raise MlflowException(
"Got invalid timestamp %s for metric '%s' (value=%s). Timestamp must be a nonnegative "
"long (64-bit integer) " % (timestamp, key, value),
INVALID_PARAMETER_VALUE)


def _validate_param_name(name):
"""Check that `name` is a valid parameter name and raise an exception if it isn't."""
if not _VALID_PARAM_AND_METRIC_NAMES.match(name):
raise Exception("Invalid parameter name: '%s'. %s" % (name, _BAD_CHARACTERS_MESSAGE))
raise MlflowException("Invalid parameter name: '%s'. %s" % (name, _BAD_CHARACTERS_MESSAGE),
INVALID_PARAMETER_VALUE)
if path_not_unique(name):
raise Exception("Invalid parameter name: '%s'. %s" % (name, bad_path_message(name)))
raise MlflowException("Invalid parameter name: '%s'. %s" % (name, bad_path_message(name)),
INVALID_PARAMETER_VALUE)


def _validate_tag_name(name):
"""Check that `name` is a valid tag name and raise an exception if it isn't."""
# Reuse param & metric check.
if not _VALID_PARAM_AND_METRIC_NAMES.match(name):
raise Exception("Invalid tag name: '%s'. %s" % (name, _BAD_CHARACTERS_MESSAGE))
raise MlflowException("Invalid tag name: '%s'. %s" % (name, _BAD_CHARACTERS_MESSAGE),
INVALID_PARAMETER_VALUE)
if path_not_unique(name):
raise Exception("Invalid tag name: '%s'. %s" % (name, bad_path_message(name)))
raise MlflowException("Invalid tag name: '%s'. %s" % (name, bad_path_message(name)),
INVALID_PARAMETER_VALUE)


def _validate_run_id(run_id):
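As a rough illustration of the new _validate_metric helper's behavior (an internal utility; this sketch only demonstrates the error path):

from mlflow.exceptions import MlflowException
from mlflow.utils.validation import _validate_metric

try:
    # A non-numeric metric value should be rejected with INVALID_PARAMETER_VALUE.
    _validate_metric("rmse", "not-a-number", 1552000000)
except MlflowException as e:
    print(e.error_code, e)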
120 changes: 117 additions & 3 deletions tests/tracking/test_tracking.py
@@ -1,19 +1,22 @@
import copy
import filecmp
import os
import random
import tempfile
import time

import attrdict
import mock
import pytest

import mlflow
from mlflow import tracking
from mlflow.entities import RunStatus, LifecycleStage
from mlflow.entities import RunStatus, LifecycleStage, Metric, Param, RunTag
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import ErrorCode, INVALID_PARAMETER_VALUE
from mlflow.tracking.client import MlflowClient
from mlflow.tracking.fluent import start_run, end_run
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE
from tests.projects.utils import tracking_uri_mock


@@ -148,6 +151,40 @@ def test_start_and_end_run(tracking_uri_mock):
assert expected_pairs[metric.key] == metric.value


def test_log_batch(tracking_uri_mock):
expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
expected_params = {"param-key0": "param-val0", "param-key1": "param-val1"}
exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
approx_expected_tags = set([MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])

t = int(time.time())
metrics = [Metric(key=key, value=value, timestamp=t) for key, value in expected_metrics.items()]
params = [Param(key=key, value=value) for key, value in expected_params.items()]
tags = [RunTag(key=key, value=value) for key, value in exact_expected_tags.items()]

active_run = start_run()
run_uuid = active_run.info.run_uuid
with active_run:
mlflow.tracking.MlflowClient().log_batch(run_id=run_uuid, metrics=metrics, params=params,
tags=tags)
finished_run = tracking.MlflowClient().get_run(run_uuid)
# Validate metrics
assert len(finished_run.data.metrics) == 2
for metric in finished_run.data.metrics:
assert expected_metrics[metric.key] == metric.value
# Validate tags (exact values for our tags; presence only for automatically-set tags)
assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
for tag in finished_run.data.tags:
if tag.key in approx_expected_tags:
pass
else:
assert exact_expected_tags[tag.key] == tag.value
# Validate params
assert len(finished_run.data.params) == 2
for param in finished_run.data.params:
assert expected_params[param.key] == param.value


def test_log_metric(tracking_uri_mock):
active_run = start_run()
run_uuid = active_run.info.run_uuid
@@ -164,11 +201,52 @@ def test_log_metric(tracking_uri_mock):
assert expected_pairs[metric.key] == metric.value


def test_log_metric_validation(tracking_uri_mock):
def test_log_metrics(tracking_uri_mock):
active_run = start_run()
run_uuid = active_run.info.run_uuid
expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
with active_run:
mlflow.log_metric("name_1", 25)
mlflow.log_metric("nested/nested/name", 45)
mlflow.log_metrics(expected_metrics)
finished_run = tracking.MlflowClient().get_run(run_uuid)
# Validate metric key/values match what we expect, and that all metrics have the same timestamp
common_timestamp = finished_run.data.metrics[0].timestamp
assert len(finished_run.data.metrics) == 3
for metric in finished_run.data.metrics:
assert expected_metrics[metric.key] == metric.value
assert metric.timestamp == common_timestamp


@pytest.fixture
def get_store_mock(tmpdir):
with mock.patch("mlflow.store.file_store.FileStore.log_batch") as _get_store_mock:
yield _get_store_mock


def test_set_tags(tracking_uri_mock):
exact_expected_tags = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
approx_expected_tags = set([MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])
active_run = start_run()
run_uuid = active_run.info.run_uuid
with active_run:
mlflow.set_tags(exact_expected_tags)
finished_run = tracking.MlflowClient().get_run(run_uuid)
# Validate tags
assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
for tag in finished_run.data.tags:
if tag.key in approx_expected_tags:
pass
else:
assert exact_expected_tags[tag.key] == tag.value


def test_log_metric_validation(tracking_uri_mock):
active_run = start_run()
run_uuid = active_run.info.run_uuid
with active_run, pytest.raises(MlflowException) as e:
mlflow.log_metric("name_1", "apple")
assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
finished_run = tracking.MlflowClient().get_run(run_uuid)
assert len(finished_run.data.metrics) == 0

Expand All @@ -190,6 +268,42 @@ def test_log_param(tracking_uri_mock):
assert expected_pairs[param.key] == param.value


def test_log_params(tracking_uri_mock):
expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
active_run = start_run()
run_uuid = active_run.info.run_uuid
with active_run:
mlflow.log_params(expected_params)
finished_run = tracking.MlflowClient().get_run(run_uuid)
# Validate params
assert len(finished_run.data.params) == 3
for param in finished_run.data.params:
assert expected_params[param.key] == param.value


def test_log_batch_validates_entity_names_and_values(tracking_uri_mock):
active_run = start_run()
bad_kwargs = {
"metrics": [
[Metric(key="../bad/metric/name", value=0.3, timestamp=3)],
[Metric(key="ok-name", value="non-numerical-value", timestamp=3)],
[Metric(key="ok-name", value=0.3, timestamp="non-numerical-timestamp")],
],
"params": [[Param(key="../bad/param/name", value="my-val")]],
"tags": [[Param(key="../bad/tag/name", value="my-val")]],
}
with active_run:
for kwarg, bad_values in bad_kwargs.items():
for bad_kwarg_value in bad_values:
final_kwargs = {
"run_id": active_run.info.run_uuid, "metrics": [], "params": [], "tags": [],
}
final_kwargs[kwarg] = bad_kwarg_value
with pytest.raises(MlflowException) as e:
tracking.MlflowClient().log_batch(**final_kwargs)
assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_log_artifact(tracking_uri_mock):
artifact_src_dir = tempfile.mkdtemp()
# Create artifacts