Skip to content

Commit

Permalink
fix: fix patching in unit tests (#264)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Zhu <daniezh@amazon.com>
  • Loading branch information
danielezhu and Daniel Zhu committed Apr 25, 2024
1 parent 5ac8466 commit c9c0e59
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions test/unit/data_loaders/test_data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_open_invalid_local_data_file(self, invalid_file_path):
class TestS3DataFile:
@pytest.mark.parametrize("file_path", FILE_PATHS)
def test_open_s3_data_file(self, file_path):
with patch("src.fmeval.data_loaders.data_sources.boto3.client") as mock_boto3_client:
with patch("fmeval.data_loaders.data_sources.boto3.client") as mock_boto3_client:
mock_s3_client = Mock()
mock_boto3_client.return_value = mock_s3_client
with io.StringIO() as buf:
Expand All @@ -49,7 +49,7 @@ def test_open_s3_data_file(self, file_path):

@pytest.mark.parametrize("invalid_file_path", INVALID_FILE_PATHS)
def test_open_invalid_s3_data_file(self, invalid_file_path):
with patch("src.fmeval.data_loaders.data_sources.boto3.client") as mock_boto3_client:
with patch("fmeval.data_loaders.data_sources.boto3.client") as mock_boto3_client:
mock_s3_client = Mock()
mock_s3_client.get_object.side_effect = botocore.errorfactory.ClientError({"error": "blah"}, "blah")
mock_boto3_client.return_value = mock_s3_client
Expand Down
12 changes: 6 additions & 6 deletions test/unit/data_loaders/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_get_data_source_provides_s3_data_source(self):
WHEN get_data_source is called
THEN an S3DataFile with the correct URI is returned
"""
with patch("src.fmeval.data_loaders.data_sources.boto3.client") as mock_s3_client:
with patch("fmeval.data_loaders.data_sources.boto3.client") as mock_s3_client:
mock_s3_client.get_object = Mock(return_value={"ContentType": "binary/octet-stream"})
dataset_uri = S3_PREFIX + DATASET_URI
data_source = get_data_source(dataset_uri)
Expand Down Expand Up @@ -221,7 +221,7 @@ def test_get_data_sources_s3_directory_exception(self):
WHEN get_data_source is called
THEN the correct exception is raised
"""
with patch("src.fmeval.data_loaders.data_sources.boto3.client") as mock_s3_client:
with patch("fmeval.data_loaders.data_sources.boto3.client") as mock_s3_client:
mock_s3_client.return_value.get_object = Mock(
return_value={"ContentType": "application/x-directory; charset=UTF-8"}
)
Expand All @@ -246,8 +246,8 @@ def test_get_data_source_invalid_local_path(self):

def test_get_data_source_invalid_dataset_path(self):
with (
patch("src.fmeval.data_loaders.util._is_valid_local_path", return_value=False),
patch("src.fmeval.data_loaders.util._is_valid_s3_uri", return_value=False),
patch("fmeval.data_loaders.util._is_valid_local_path", return_value=False),
patch("fmeval.data_loaders.util._is_valid_s3_uri", return_value=False),
):
with pytest.raises(EvalAlgorithmClientError, match="Invalid dataset path"):
get_data_source("unused")
Expand All @@ -259,7 +259,7 @@ def test_is_valid_s3_uri_success(self):
THEN True is returned
"""
dataset_uri = S3_PREFIX + DATASET_URI
with patch("src.fmeval.data_loaders.data_sources.boto3.client") as mock_s3_client:
with patch("fmeval.data_loaders.data_sources.boto3.client") as mock_s3_client:
mock_s3_client.return_value.get_object.return_value = Mock()
assert _is_valid_s3_uri(dataset_uri)

Expand All @@ -279,7 +279,7 @@ def test_is_valid_s3_uri_client_error(self):
WHEN _is_valid_s3_uri is called
THEN False is returned
"""
with patch("src.fmeval.data_loaders.data_sources.boto3.client") as mock_s3_client:
with patch("fmeval.data_loaders.data_sources.boto3.client") as mock_s3_client:
mock_s3_client.return_value.get_object = Mock(
side_effect=botocore.errorfactory.ClientError({"error": "blah"}, "blah")
)
Expand Down
4 changes: 2 additions & 2 deletions test/unit/model_runners/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@


def test_get_user_agent_extra():
with patch("src.fmeval.model_runners.util.get_fmeval_package_version", return_value="1.0.1") as get_package_ver:
assert get_user_agent_extra().endswith("fmeval/1.0.1")
with patch("fmeval.model_runners.util.get_fmeval_package_version", return_value="1.0.0") as get_package_ver:
assert get_user_agent_extra().endswith("fmeval/1.0.0")
os.environ[DISABLE_FMEVAL_TELEMETRY] = "True"
assert "fmeval" not in get_user_agent_extra()
del os.environ[DISABLE_FMEVAL_TELEMETRY]
Expand Down

0 comments on commit c9c0e59

Please sign in to comment.