Skip to content

Commit

Permalink
Prevent an infinite loop in download_artifacts (mlflow#1605)
Browse files Browse the repository at this point in the history
With some data stores, ArtifactRepository.list_artifacts returns the current path in addition to its children (reported in mlflow#1458). This PR fixes ArtifactRepository.download_artifacts, which used to go into an infinite loop in such cases.
  • Loading branch information
sueann committed Jul 19, 2019
1 parent 722e897 commit 3d39607
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mlflow/store/artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def download_artifacts(self, artifact_path, dst_path=None):
"""

# TODO: Probably need to add a more efficient method to stream just a single artifact
# without downloading it, or to get a pre-signed URL for cloud storage.
# without downloading it, or to get a pre-signed URL for cloud storage.

def download_artifacts_into(artifact_path, dest_dir):
basename = posixpath.basename(artifact_path)
Expand All @@ -85,6 +85,9 @@ def download_artifacts_into(artifact_path, dest_dir):
if not os.path.exists(local_path):
os.mkdir(local_path)
for file_info in listing:
# prevent an infinite loop (sometimes the current path is listed e.g. as ".")
if file_info.path == "." or file_info.path == artifact_path:
continue
download_artifacts_into(artifact_path=file_info.path, dest_dir=local_path)
else:
self._download_file(remote_file_path=artifact_path, local_path=local_path)
Expand Down
43 changes: 43 additions & 0 deletions tests/store/test_artifact_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import mock
import pytest

from mlflow.entities import FileInfo
from mlflow.store.artifact_repo import ArtifactRepository


class ArtifactRepositoryImpl(ArtifactRepository):
    """Bare-bones concrete ``ArtifactRepository`` for exercising ``download_artifacts``.

    The logging and listing methods are intentionally unimplemented and raise
    ``NotImplementedError``; tests are expected to patch ``list_artifacts``
    with a stub. ``_download_file`` performs no transfer and only prints the
    remote path it was asked for.
    """

    def log_artifact(self, local_file, artifact_path=None):
        """Not supported by this test double."""
        raise NotImplementedError()

    def log_artifacts(self, local_dir, artifact_path=None):
        """Not supported by this test double."""
        raise NotImplementedError()

    def list_artifacts(self, path):
        """Not supported unless patched by the test."""
        raise NotImplementedError()

    def _download_file(self, remote_file_path, local_path):
        """Print the requested remote path instead of downloading anything."""
        print("download_file called with ", remote_file_path)


@pytest.mark.parametrize("base_uri, download_arg, list_return_val, expected_args", [
    ('12345/model', '', ['modelfile'], ['modelfile']),
    ('12345/model', '', ['.', 'modelfile'], ['modelfile']),
    ('12345', 'model', ['model/modelfile'], ['model/modelfile']),
    ('12345', 'model', ['model', 'model/modelfile'], ['model/modelfile']),
    ('', '12345/model', ['12345/model/modelfile'], ['12345/model/modelfile']),
    ('', '12345/model', ['12345/model', '12345/model/modelfile'], ['12345/model/modelfile']),
])
def test_download_artifacts_does_not_infinitely_loop(base_uri, download_arg, list_return_val,
                                                     expected_args):
    """Check that ``download_artifacts`` returns when a listing includes the listed path.

    Several parametrized cases make the stubbed ``list_artifacts`` return the
    directory itself (as ``"."`` or as its own path) alongside its child; the
    test passes as long as the download call completes.

    NOTE(review): ``expected_args`` is supplied by the parametrization but is
    not referenced in the body, so the paths handed to ``_download_file`` are
    not actually verified here.
    """

    def list_artifacts_mock(self, path):
        # Only paths ending in "model" report any contents; every other path
        # lists as empty.
        if not path.endswith("model"):
            return []
        return [FileInfo(name, False, 123) for name in list_return_val]

    # Install the stub directly on the class so it binds as a regular method.
    patched_listing = mock.patch.object(ArtifactRepositoryImpl, "list_artifacts",
                                        new=list_artifacts_mock)
    with patched_listing:
        repo = ArtifactRepositoryImpl(base_uri)
        repo.download_artifacts(download_arg)

0 comments on commit 3d39607

Please sign in to comment.