Skip to content

Commit

Permalink
mlflow.set_experiment to activate an experiment before starting runs (m…
Browse files Browse the repository at this point in the history
…lflow#462)

* mlflow.set_experiment to activate an experiment before starting runs

* documentation for set_experiment

* addressing comments

* addressing review comments

* lint
  • Loading branch information
mparkhe committed Sep 19, 2018
1 parent b332ef6 commit 076a39b
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 19 deletions.
4 changes: 4 additions & 0 deletions docs/source/tracking.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ directory. The URI defaults to ``mlruns``.
:py:func:`mlflow.create_experiment` creates a new experiment and returns its ID. Runs can be
launched under the experiment by passing the experiment ID to ``mlflow.start_run``.

:py:func:`mlflow.set_experiment` sets the experiment with the given name as the active experiment.
If no experiment with that name exists, it creates one first. If you do not pass an experiment ID
to :py:func:`mlflow.start_run`, new runs are launched under this experiment.

:py:func:`mlflow.start_run` returns the currently active run (if one exists), or starts a new run
and returns a :py:class:`mlflow.ActiveRun` object usable as a context manager for the
current run. You do not need to call ``start_run`` explicitly: calling one of the logging functions
Expand Down
3 changes: 2 additions & 1 deletion mlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@
set_tracking_uri = tracking.set_tracking_uri
get_tracking_uri = tracking.get_tracking_uri
create_experiment = mlflow.tracking.fluent.create_experiment
set_experiment = mlflow.tracking.fluent.set_experiment


run = projects.run


__all__ = ["ActiveRun", "log_param", "log_metric", "set_tag", "log_artifacts", "log_artifact",
"active_run", "start_run", "end_run", "get_artifact_uri", "set_tracking_uri",
"create_experiment", "run"]
"create_experiment", "set_experiment", "run"]
46 changes: 35 additions & 11 deletions mlflow/tracking/fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,50 @@

from __future__ import print_function

import atexit
import numbers
import os

import atexit
import sys
import time

from mlflow.entities import Experiment, Run, SourceType
from mlflow.tracking.client import MlflowClient
from mlflow.utils import env
from mlflow.utils.databricks_utils import is_in_databricks_notebook, get_notebook_id, \
get_notebook_path, get_webapp_url
from mlflow.utils.logging_utils import eprint
from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_WEBAPP_URL, \
MLFLOW_DATABRICKS_NOTEBOOK_PATH, \
MLFLOW_DATABRICKS_NOTEBOOK_ID
from mlflow.utils.validation import _validate_run_id
from mlflow.tracking.client import MlflowClient


_EXPERIMENT_ID_ENV_VAR = "MLFLOW_EXPERIMENT_ID"
_RUN_ID_ENV_VAR = "MLFLOW_RUN_ID"
_active_run = None
_active_experiment_id = None


def set_experiment(experiment_name):
    """
    Set given experiment as active experiment. If experiment does not exist, create an experiment
    with provided name.

    The ID of the activated experiment is stored in the module-level ``_active_experiment_id``.

    :param experiment_name: Name of experiment to be activated.
    """
    global _active_experiment_id
    client = MlflowClient()
    experiment = client.get_experiment_by_name(experiment_name)
    exp_id = experiment.experiment_id if experiment else None
    # Compare against None instead of relying on truthiness: an existing experiment whose ID
    # is 0 is falsy and would otherwise be mistaken for a missing one and created a second time.
    if exp_id is None:
        print("INFO: '{}' does not exist. Creating a new experiment".format(experiment_name))
        exp_id = client.create_experiment(experiment_name)
    _active_experiment_id = exp_id


class ActiveRun(Run): # pylint: disable=W0223
"""Wrapper around :py:class:`mlflow.entities.Run` to enable using Python ``with`` syntax."""

def __init__(self, run):
Run.__init__(self, run.info, run.data)

Expand Down Expand Up @@ -56,9 +76,11 @@ def start_run(run_uuid=None, experiment_id=None, source_name=None, source_versio
and metrics under that run. The run's end time is unset and its status
is set to running, but the run's other attributes (``source_version``,
``source_type``, etc.) are not changed.
:param experiment_id: ID of the experiment under which to create the current run.
Used only when ``run_uuid`` is unspecified. If unspecified,
the run is created under the default experiment.
:param experiment_id: ID of the experiment under which to create the current run (applicable
only when ``run_uuid`` is not specified). If ``experiment_id`` argument
is unspecified, will look for valid experiment in the following order:
activated using ``set_experiment``, ``MLFLOW_EXPERIMENT_ID`` env variable,
or the default experiment.
:param source_name: Name of the source file or URI of the project to be associated with the run.
If none provided defaults to the current file.
:param source_version: Optional Git commit hash to associate with the run.
Expand Down Expand Up @@ -159,8 +181,8 @@ def log_metric(key, value):
:param value: Metric value (float).
"""
if not isinstance(value, numbers.Number):
print("WARNING: The metric {}={} was not logged because the value is not a number.".format(
key, value), file=sys.stderr)
eprint("WARNING: The metric {}={} was not logged because the value is not a number.".format(
key, value))
return
run_id = _get_or_start_run().info.run_uuid
MlflowClient().log_metric(run_id, key, value, int(time.time()))
Expand Down Expand Up @@ -239,15 +261,17 @@ def _get_source_type():


def _get_experiment_id():
    """
    Resolve the ID of the experiment under which new runs are created.

    Sources are consulted in this order: the module-level ``_active_experiment_id``, the value
    read through ``env.get_env`` for ``_EXPERIMENT_ID_ENV_VAR``, and finally
    ``Experiment.DEFAULT_EXPERIMENT_ID``.

    :return: Experiment ID converted to ``int``.
    """
    # Explicit None check: an active experiment ID of 0 is falsy, and an ``or`` chain would
    # silently skip it in favour of the environment variable or the default.
    if _active_experiment_id is not None:
        return int(_active_experiment_id)
    return int(env.get_env(_EXPERIMENT_ID_ENV_VAR) or Experiment.DEFAULT_EXPERIMENT_ID)


def _get_git_commit(path):
try:
from git import Repo, InvalidGitRepositoryError, GitCommandNotFound, NoSuchPathError
except ImportError as e:
print("Notice: failed to import Git (the Git executable is probably not on your PATH),"
" so Git SHA is not available. Error: %s" % e, file=sys.stderr)
eprint("Notice: failed to import Git (the Git executable is probably not on your PATH),"
" so Git SHA is not available. Error: %s" % e)
return None
try:
if os.path.isfile(path):
Expand Down
42 changes: 35 additions & 7 deletions tests/tracking/test_tracking.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import random
from contextlib import contextmanager
import filecmp
import os
import random
import tempfile
import shutil

import mock
import pytest

from mlflow.store.file_store import FileStore
from mlflow.entities import RunStatus
import mlflow
from mlflow import tracking
from mlflow.entities import RunStatus
from mlflow.tracking.fluent import start_run, end_run
import mlflow
from mlflow.utils.file_utils import TempDir


def test_create_experiment():
Expand All @@ -34,6 +31,37 @@ def test_create_experiment():
tracking.set_tracking_uri(None)


def test_set_experiment():
    """Check ``mlflow.set_experiment`` argument validation and that it routes new runs."""
    # The experiment name is a required positional argument.
    with pytest.raises(TypeError):
        mlflow.set_experiment()

    with pytest.raises(Exception):
        mlflow.set_experiment(None)

    with pytest.raises(Exception):
        mlflow.set_experiment("")

    try:
        with TempDir() as tracking_uri:
            tracking.set_tracking_uri(tracking_uri.path())

            # Activating an experiment that already exists: new runs are created under it.
            name = "random_exp"
            exp_id = mlflow.create_experiment(name)
            mlflow.set_experiment(name)
            run = start_run()
            assert run.info.experiment_id == exp_id
            end_run()

            # Activating a name that was never created: the experiment must exist afterwards
            # and new runs are created under it.
            another_name = "another_experiment"
            mlflow.set_experiment(another_name)
            client = mlflow.tracking.MlflowClient()
            another_experiment = client.get_experiment_by_name(another_name)
            assert another_experiment is not None
            another_run = start_run()
            assert another_run.info.experiment_id == another_experiment.experiment_id
            end_run()
    finally:
        # Need to do this to clear active experiment to restore state
        mlflow.tracking.fluent._active_experiment_id = None
        # Mirror test_create_experiment: reset the tracking URI so later tests do not keep
        # pointing at this test's temporary directory.
        tracking.set_tracking_uri(None)


def test_no_nested_run():
try:
tracking.set_tracking_uri(tempfile.mkdtemp())
Expand Down

0 comments on commit 076a39b

Please sign in to comment.