Skip to content

Commit

Permalink
Azure artifact store: fix path resolution error when artifact root is…
Browse files Browse the repository at this point in the history
… container root (mlflow#769)

* Fix relative path resolution in azure blob artifact repo

* Add test

* Naming change

* Validate that blob names begin with artifact path

* Add negative test case

* Lint

* test fix
  • Loading branch information
dbczumar committed Dec 21, 2018
1 parent 6a13647 commit 4a2ba72
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 4 deletions.
13 changes: 11 additions & 2 deletions mlflow/store/azure_blob_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,22 @@ def list_artifacts(self, path=None):
while True:
results = self.client.list_blobs(container, prefix=prefix, delimiter='/', marker=marker)
for r in results:
if not r.name.startswith(artifact_path):
raise ValueError(
"The name of the listed Azure blob does not begin with the specified"
" artifact path. Artifact path: {artifact_path}. Blob name:"
" {blob_name}".format(artifact_path=artifact_path, blob_name=r.name))
if isinstance(r, BlobPrefix): # This is a prefix for items in a subdirectory
subdir = r.name[len(artifact_path)+1:]
# Separator needs to be fixed as '/' because of azure blob storage pattern.
# Do not change to os.relpath because in Windows system path separator is '\'
subdir = posixpath.relpath(path=r.name, start=artifact_path)
if subdir.endswith("/"):
subdir = subdir[:-1]
infos.append(FileInfo(subdir, True, None))
else: # Just a plain old blob
file_name = r.name[len(artifact_path)+1:]
# Separator needs to be fixed as '/' because of azure blob storage pattern.
# Do not change to os.relpath because in Windows system path separator is '\'
file_name = posixpath.relpath(path=r.name, start=artifact_path)
infos.append(FileInfo(file_name, False, r.properties.content_length))
# Check whether a new marker is returned, meaning we have to make another request
if results.next_marker:
Expand Down
91 changes: 89 additions & 2 deletions tests/store/test_azure_blob_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


TEST_ROOT_PATH = "some/path"
TEST_URI = "wasbs://container@account.blob.core.windows.net/" + TEST_ROOT_PATH
TEST_BLOB_CONTAINER_ROOT = "wasbs://container@account.blob.core.windows.net/"
TEST_URI = os.path.join(TEST_BLOB_CONTAINER_ROOT, TEST_ROOT_PATH)


class MockBlobList(object):
Expand Down Expand Up @@ -155,7 +156,9 @@ def create_file(container, cloud_path, local_path):
"container", TEST_ROOT_PATH + "/test.txt", mock.ANY)


def test_download_directory_artifact(mock_client, tmpdir):
def test_download_directory_artifact_succeeds_when_artifact_root_is_not_blob_container_root(
mock_client, tmpdir):
assert TEST_URI is not TEST_BLOB_CONTAINER_ROOT
repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

file_path_1 = "file_1"
Expand Down Expand Up @@ -198,3 +201,87 @@ def create_file(container, cloud_path, local_path):
dir_contents = os.listdir(tmpdir.strpath)
assert file_path_1 in dir_contents
assert file_path_2 in dir_contents


def test_download_directory_artifact_succeeds_when_artifact_root_is_blob_container_root(
mock_client, tmpdir):
repo = AzureBlobArtifactRepository(TEST_BLOB_CONTAINER_ROOT, mock_client)

subdir_path = "my_directory"
dir_prefix = BlobPrefix()
dir_prefix.name = subdir_path

file_path_1 = "file_1"
file_path_2 = "file_2"

blob_props_1 = BlobProperties()
blob_props_1.content_length = 42
blob_1 = Blob(os.path.join(subdir_path, file_path_1), props=blob_props_1)

blob_props_2 = BlobProperties()
blob_props_2.content_length = 42
blob_2 = Blob(os.path.join(subdir_path, file_path_2), props=blob_props_2)

def get_mock_listing(*args, **kwargs):
"""
Produces a mock listing that only contains content if the specified prefix is the artifact
root or a relevant subdirectory. This allows us to mock `list_artifacts` during the
`_download_artifacts_into` subroutine without recursively listing the same artifacts at
every level of the directory traversal.
"""
# pylint: disable=unused-argument
if os.path.abspath(kwargs["prefix"]) == "/":
return MockBlobList([dir_prefix])
if os.path.abspath(kwargs["prefix"]) == os.path.abspath(subdir_path):
return MockBlobList([blob_1, blob_2])
else:
return MockBlobList([])

def create_file(container, cloud_path, local_path):
# pylint: disable=unused-argument
fname = os.path.basename(local_path)
f = tmpdir.join(fname)
f.write("hello world!")

mock_client.list_blobs.side_effect = get_mock_listing
mock_client.get_blob_to_path.side_effect = create_file

# Ensure that the root directory can be downloaded successfully
repo.download_artifacts("")
# Ensure that the `mkfile` side effect copied all of the download artifacts into `tmpdir`
dir_contents = os.listdir(tmpdir.strpath)
assert file_path_1 in dir_contents
assert file_path_2 in dir_contents


def test_download_artifact_throws_value_error_when_listed_blobs_do_not_contain_artifact_root_prefix(
mock_client):
repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

# Create a "bad blob" with a name that is not prefixed by the root path of the artifact store
bad_blob_props = BlobProperties()
bad_blob_props.content_length = 42
bad_blob = Blob("file_path", props=bad_blob_props)

def get_mock_listing(*args, **kwargs):
"""
Produces a mock listing that only contains content if the
specified prefix is the artifact root. This allows us to mock
`list_artifacts` during the `_download_artifacts_into` subroutine
without recursively listing the same artifacts at every level of the
directory traversal.
"""
# pylint: disable=unused-argument
if os.path.abspath(kwargs["prefix"]) == os.path.abspath(TEST_ROOT_PATH):
# Return a blob that is not prefixed by the root path of the artifact store. This
# should result in an exception being raised
return MockBlobList([bad_blob])
else:
return MockBlobList([])

mock_client.list_blobs.side_effect = get_mock_listing

with pytest.raises(ValueError) as exc:
repo.download_artifacts("")

assert "Azure blob does not begin with the specified artifact path" in str(exc)

0 comments on commit 4a2ba72

Please sign in to comment.