Adding GCS artifact storage capabilities. (mlflow#152)
* Adding GCS artifact storage capabilities.

Add google-cloud-storage as a dependency.

Fixing a couple of bugs with the GCS store.

* Fix pep8 issues.

* Add tests for GCSArtifactRepo.

* Trigger

* pep8 fix.

* Use super instead of setting artifact_uri

* Removing an unnecessary lambda.

* Adding GCS information to the storage docs.

* Store the old GOOGLE_APPLICATION_CREDENTIALS environment variable, then restore it after the GCS tests.

* Convert from unittest to pytest.

* Verifying call signature of gcs client calls.

* Adding tests for log_artifacts and _download_artifacts.

* Ignore redefined-outer-name in pytest fixtures.
bnekolny authored and smurching committed Jul 17, 2018
1 parent 39db66f commit 0ba50e7
Showing 4 changed files with 202 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/source/tracking.rst
@@ -219,6 +219,9 @@ For the clients and server to access the artifact location, you should configure
provider credentials as normal. For example, for S3, you can set the ``AWS_ACCESS_KEY_ID``
and ``AWS_SECRET_ACCESS_KEY`` environment variables, use an IAM role, or configure a default
profile in ``~/.aws/credentials``. See `Set up AWS Credentials and Region for Development <https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/setup-credentials.html>`_ for more info.
To use Google Cloud Storage, set the artifact root to ``gs://<storage_bucket>`` and provide credentials as described in the
`Authentication <https://google-cloud.readthedocs.io/en/latest/core/auth.html>`_ documentation.
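
For example, once credentials are configured, artifacts can be logged to GCS directly (a minimal sketch; the bucket name is hypothetical):

    from mlflow.store.artifact_repo import ArtifactRepository

    # The gs:// scheme resolves to the new GCSArtifactRepository
    repo = ArtifactRepository.from_artifact_uri("gs://my-mlflow-bucket/artifacts")
    repo.log_artifact("/tmp/model.pkl")  # uploads to gs://my-mlflow-bucket/artifacts/model.pkl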


Networking
^^^^^^^^^^
89 changes: 89 additions & 0 deletions mlflow/store/artifact_repo.py
@@ -5,6 +5,7 @@
import os

import boto3
from google.cloud import storage as gcs_storage

from mlflow.utils.file_utils import (mkdir, exists, list_all, get_relative_path,
                                     get_file_info, build_path, TempDir)
@@ -75,6 +76,8 @@ def from_artifact_uri(artifact_uri):
"""
if artifact_uri.startswith("s3:/"):
return S3ArtifactRepository(artifact_uri)
elif artifact_uri.startswith("gs:/"):
return GCSArtifactRepository(artifact_uri)
else:
return LocalArtifactRepository(artifact_uri)

@@ -186,3 +189,89 @@ def _download_artifacts_into(self, artifact_path, dest_dir):
            s3_path = build_path(s3_path, artifact_path)
            boto3.client('s3').download_file(bucket, s3_path, local_path)
        return local_path


class GCSArtifactRepository(ArtifactRepository):
    """Stores artifacts on Google Cloud Storage.
    Assumes Google Cloud credentials are available in the environment;
    see https://google-cloud.readthedocs.io/en/latest/core/auth.html"""

    def __init__(self, artifact_uri, client=gcs_storage):
        self.gcs = client
        super(GCSArtifactRepository, self).__init__(artifact_uri)

    @staticmethod
    def parse_gcs_uri(uri):
        """Parse a GCS URI, returning (bucket, path)"""
        parsed = urllib.parse.urlparse(uri)
        if parsed.scheme != "gs":
            raise Exception("Not a GCS URI: %s" % uri)
        path = parsed.path
        if path.startswith('/'):
            path = path[1:]
        return parsed.netloc, path

    def log_artifact(self, local_file, artifact_path=None):
        (bucket, dest_path) = self.parse_gcs_uri(self.artifact_uri)
        if artifact_path:
            dest_path = build_path(dest_path, artifact_path)
        dest_path = build_path(dest_path, os.path.basename(local_file))

        gcs_bucket = self.gcs.Client().get_bucket(bucket)
        blob = gcs_bucket.blob(dest_path)
        blob.upload_from_filename(local_file)

    def log_artifacts(self, local_dir, artifact_path=None):
        (bucket, dest_path) = self.parse_gcs_uri(self.artifact_uri)
        if artifact_path:
            dest_path = build_path(dest_path, artifact_path)
        gcs_bucket = self.gcs.Client().get_bucket(bucket)

        local_dir = os.path.abspath(local_dir)
        for (root, _, filenames) in os.walk(local_dir):
            upload_path = dest_path
            if root != local_dir:
                rel_path = get_relative_path(local_dir, root)
                upload_path = build_path(dest_path, rel_path)
            for f in filenames:
                path = build_path(upload_path, f)
                gcs_bucket.blob(path).upload_from_filename(build_path(root, f))

    def list_artifacts(self, path=None):
        (bucket, artifact_path) = self.parse_gcs_uri(self.artifact_uri)
        dest_path = artifact_path
        if path:
            dest_path = build_path(dest_path, path)
        infos = []
        prefix = dest_path + "/"

        results = self.gcs.Client().get_bucket(bucket).list_blobs(prefix=prefix)
        for result in results:
            is_dir = result.name.endswith('/')
            if is_dir:
                # Strip the artifact root prefix and the trailing slash
                blob_path = result.name[len(artifact_path)+1:-1]
            else:
                blob_path = result.name[len(artifact_path)+1:]
            infos.append(FileInfo(blob_path, is_dir, result.size))
        return sorted(infos, key=lambda f: f.path)

    def download_artifacts(self, artifact_path):
        with TempDir(remove_on_exit=False) as tmp:
            return self._download_artifacts_into(artifact_path, tmp.path())

    def _download_artifacts_into(self, artifact_path, dest_dir):
        """Private version of download_artifacts that takes a destination directory."""
        basename = os.path.basename(artifact_path)
        local_path = build_path(dest_dir, basename)
        listing = self.list_artifacts(artifact_path)
        if len(listing) > 0:
            # artifact_path is a directory, so make a directory for it and download everything
            os.mkdir(local_path)
            for file_info in listing:
                self._download_artifacts_into(file_info.path, local_path)
        else:
            (bucket, remote_path) = self.parse_gcs_uri(self.artifact_uri)
            remote_path = build_path(remote_path, artifact_path)
            gcs_bucket = self.gcs.Client().get_bucket(bucket)
            gcs_bucket.get_blob(remote_path).download_to_filename(local_path)
        return local_path
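
Taken together, a quick sketch of how the new repository is used (illustrative only; the bucket and run paths are hypothetical, and the live calls require GCS credentials):

    from mlflow.store.artifact_repo import GCSArtifactRepository

    # parse_gcs_uri splits a gs:// URI into (bucket, object path)
    assert GCSArtifactRepository.parse_gcs_uri("gs://my-bucket/some/path") == ("my-bucket", "some/path")

    # With credentials in place, artifacts land under the repository's URI:
    repo = GCSArtifactRepository("gs://my-bucket/mlruns/0/run_id/artifacts")
    repo.log_artifacts("/tmp/outputs")                 # uploads the directory tree
    print(repo.list_artifacts())                       # FileInfo entries, sorted by path
    local_path = repo.download_artifacts("model.pkl")  # fetches into a fresh temp dir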
1 change: 1 addition & 0 deletions setup.py
@@ -43,6 +43,7 @@ def package_files(directory):
        'gitpython>=2.1.0',
        'pyyaml',
        'boto3',
        'google-cloud-storage',
        'querystring_parser',
    ],
    entry_points='''
109 changes: 109 additions & 0 deletions tests/store/test_gcs_artifact_repo.py
@@ -0,0 +1,109 @@
# pylint: disable=redefined-outer-name
import os
import mock
import pytest

from google.cloud.storage import client as gcs_client

from mlflow.store.artifact_repo import ArtifactRepository, GCSArtifactRepository


@pytest.fixture
def gcs_mock():
    # Make sure that the environment variable isn't set to actually make calls
    old_G_APP_CREDS = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/dev/null'

    yield mock.MagicMock(autospec=gcs_client)

    # Restore the variable, or remove it if it wasn't set to begin with
    if old_G_APP_CREDS:
        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = old_G_APP_CREDS
    else:
        del os.environ['GOOGLE_APPLICATION_CREDENTIALS']


def test_artifact_uri_factory():
    repo = ArtifactRepository.from_artifact_uri("gs://test_bucket/some/path")
    assert isinstance(repo, GCSArtifactRepository)


def test_list_artifacts_empty(gcs_mock):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)
    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.return_value = []
    assert repo.list_artifacts() == []


def test_list_artifacts(gcs_mock):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)
    mockobj = mock.Mock()
    mockobj.configure_mock(
        name='/some/path/mockeryname',
        f='/mockeryname',
        size=1,
    )
    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.return_value = [mockobj]
    assert repo.list_artifacts()[0].path == mockobj.f
    assert repo.list_artifacts()[0].file_size == mockobj.size


def test_log_artifact(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    d = tmpdir.mkdir("data")
    f = d.join("test.txt")
    f.write("hello world!")
    fpath = f.strpath

    # This will call isfile on the code path being used,
    # thus testing that it's being called with an actual file path
    gcs_mock.Client.return_value.get_bucket.return_value.blob.return_value\
        .upload_from_filename.side_effect = os.path.isfile
    repo.log_artifact(fpath)

    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().blob\
        .assert_called_with('some/path/test.txt')
    gcs_mock.Client().get_bucket().blob().upload_from_filename\
        .assert_called_with(fpath)


def test_log_artifacts(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    subd = tmpdir.mkdir("data").mkdir("subdir")
    subd.join("a.txt").write("A")
    subd.join("b.txt").write("B")
    subd.join("c.txt").write("C")

    gcs_mock.Client.return_value.get_bucket.return_value.blob.return_value\
        .upload_from_filename.side_effect = os.path.isfile
    repo.log_artifacts(subd.strpath)

    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().blob().upload_from_filename\
        .assert_has_calls([
            mock.call('%s/a.txt' % subd.strpath),
            mock.call('%s/b.txt' % subd.strpath),
            mock.call('%s/c.txt' % subd.strpath),
        ], any_order=True)


def test_download_artifacts(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def mkfile(fname):
        fname = fname.replace(tmpdir.strpath, '')
        f = tmpdir.join(fname)
        f.write("hello world!")
        return f.strpath

    gcs_mock.Client.return_value.get_bucket.return_value.get_blob.return_value\
        .download_to_filename.side_effect = mkfile

    open(repo._download_artifacts_into("test.txt", tmpdir.strpath)).read()
    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().get_blob\
        .assert_called_with('some/path/test.txt')
    gcs_mock.Client().get_bucket().get_blob()\
        .download_to_filename.assert_called_with(tmpdir + "/test.txt")
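
Since the GCS client is injected and fully mocked, the suite needs no network access or credentials. A sketch of invoking just this file programmatically (assumes pytest is installed; equivalent to running pytest on the file from the repository root):

    import pytest

    pytest.main(["tests/store/test_gcs_artifact_repo.py"])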
