forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prevent an infinite loop in download_artifacts (mlflow#1605)
With some data stores, ArtifactRepository.list_artifacts returns the current path in addition to its children (reported in mlflow#1458). This PR fixes ArtifactRepository.download_artifacts which used to go into an infinite loop in such cases.
- Loading branch information
Showing
2 changed files
with
47 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import mock | ||
import pytest | ||
|
||
from mlflow.entities import FileInfo | ||
from mlflow.store.artifact_repo import ArtifactRepository | ||
|
||
|
||
class ArtifactRepositoryImpl(ArtifactRepository): | ||
|
||
def log_artifact(self, local_file, artifact_path=None): | ||
raise NotImplementedError() | ||
|
||
def log_artifacts(self, local_dir, artifact_path=None): | ||
raise NotImplementedError() | ||
|
||
def list_artifacts(self, path): | ||
raise NotImplementedError() | ||
|
||
def _download_file(self, remote_file_path, local_path): | ||
print("download_file called with ", remote_file_path) | ||
|
||
|
||
@pytest.mark.parametrize("base_uri, download_arg, list_return_val, expected_args", [ | ||
('12345/model', '', ['modelfile'], ['modelfile']), | ||
('12345/model', '', ['.', 'modelfile'], ['modelfile']), | ||
('12345', 'model', ['model/modelfile'], ['model/modelfile']), | ||
('12345', 'model', ['model', 'model/modelfile'], ['model/modelfile']), | ||
('', '12345/model', ['12345/model/modelfile'], ['12345/model/modelfile']), | ||
('', '12345/model', ['12345/model', '12345/model/modelfile'], ['12345/model/modelfile']), | ||
]) | ||
def test_download_artifacts_does_not_infinitely_loop(base_uri, download_arg, list_return_val, | ||
expected_args): | ||
|
||
def list_artifacts_mock(self, path): | ||
if path.endswith("model"): | ||
return [FileInfo(item, False, 123) for item in list_return_val] | ||
else: | ||
return [] | ||
|
||
with mock.patch.object(ArtifactRepositoryImpl, "list_artifacts", | ||
new_callable=lambda: list_artifacts_mock): | ||
repo = ArtifactRepositoryImpl(base_uri) | ||
repo.download_artifacts(download_arg) |