From e16a704b4265020c9ff3b33b80088409680d4f72 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Fri, 24 May 2024 01:48:40 +0000 Subject: [PATCH] [Identity] Add Azure Arc key validation checks Signed-off-by: Paul Van Eck --- sdk/identity/azure-identity/CHANGELOG.md | 6 + .../azure/identity/_credentials/azure_arc.py | 56 +++++++- .../azure-identity/azure/identity/_version.py | 2 +- .../tests/test_managed_identity.py | 132 ++++++++++++++++-- .../tests/test_managed_identity_async.py | 123 +++++++++++++++- 5 files changed, 303 insertions(+), 16 deletions(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 6c549911b647..6a88fed5cd97 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -1,5 +1,11 @@ # Release History +## 1.16.1 (2024-06-11) + +### Bugs Fixed + +- Managed identity bug fixes + ## 1.16.0 (2024-04-09) ### Other Changes diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py index 68034300b819..859f625e158c 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py @@ -4,6 +4,7 @@ # ------------------------------------ import functools import os +import sys from typing import Any, Dict, Optional from azure.core.exceptions import ClientAuthenticationError @@ -24,7 +25,7 @@ def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: return ManagedIdentityClient( _per_retry_policies=[ArcChallengeAuthPolicy()], request_factory=functools.partial(_get_request, url), - **kwargs + **kwargs, ) return None @@ -70,6 +71,12 @@ def _get_secret_key(response: PipelineResponse) -> str: raise ClientAuthenticationError( message="Did not receive a correct value from WWW-Authenticate header: {}".format(header) ) from ex + + try: + _validate_key_file(key_file) + except ValueError as ex: + raise ClientAuthenticationError(message="The key file path is invalid: {}".format(ex)) from ex + with open(key_file, "r", encoding="utf-8") as file: try: return file.read() @@ -80,6 +87,53 @@ def _get_secret_key(response: PipelineResponse) -> str: ) from error +def _get_key_file_path() -> str: + """Returns the expected path for the Azure Arc MSI key file based on the current platform. + + Only Linux and Windows are supported. + + :return: The expected path. + :rtype: str + :raises ValueError: If the current platform is not supported. + """ + if sys.platform.startswith("linux"): + return "/var/opt/azcmagent/tokens" + if sys.platform.startswith("win"): + program_data_path = os.environ.get("PROGRAMDATA") + if not program_data_path: + raise ValueError("PROGRAMDATA environment variable is not set or is empty.") + return os.path.join(f"{program_data_path}", "AzureConnectedMachineAgent", "Tokens") + raise ValueError(f"Azure Arc MSI is not supported on this platform {sys.platform}") + + +def _validate_key_file(file_path: str) -> None: + """Validates that a given Azure Arc MSI file path is valid for use. + + A valid file will: + 1. Be in the expected path for the current platform. + 2. Have a `.key` extension. + 3. Be at most 4096 bytes in size. + + :param str file_path: The path to the key file. + :raises ClientAuthenticationError: If the file path is invalid. + """ + if not file_path: + raise ValueError("The file path must not be empty.") + + if not os.path.exists(file_path): + raise ValueError(f"The file path does not exist: {file_path}") + + expected_directory = _get_key_file_path() + if not os.path.dirname(file_path) == expected_directory: + raise ValueError(f"Unexpected file path from HIMDS service: {file_path}") + + if not file_path.endswith(".key"): + raise ValueError("The file path must have a '.key' extension.") + + if os.path.getsize(file_path) > 4096: + raise ValueError("The file size must be less than or equal to 4096 bytes.") + + class ArcChallengeAuthPolicy(HTTPPolicy): """Policy for handling Azure Arc's challenge authentication""" diff --git a/sdk/identity/azure-identity/azure/identity/_version.py b/sdk/identity/azure-identity/azure/identity/_version.py index d6b2b365a466..b47d22c5d31a 100644 --- a/sdk/identity/azure-identity/azure/identity/_version.py +++ b/sdk/identity/azure-identity/azure/identity/_version.py @@ -2,4 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -VERSION = "1.16.0" +VERSION = "1.16.1" diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index 805e36343baa..b0be4b91ee5e 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import os +import sys import time try: @@ -883,9 +884,10 @@ def test_azure_arc(tmpdir): "os.environ", {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): - token = ManagedIdentityCredential(transport=transport).get_token(scope) - assert token.token == access_token - assert token.expires_on == expires_on + with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None): + token = ManagedIdentityCredential(transport=transport).get_token(scope) + assert token.token == access_token + assert token.expires_on == expires_on def test_azure_arc_tenant_id(tmpdir): @@ -936,9 +938,10 @@ def test_azure_arc_tenant_id(tmpdir): "os.environ", {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): - token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token.token == access_token - assert token.expires_on == expires_on + with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None): + token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") + assert token.token == access_token + assert token.expires_on == expires_on def test_azure_arc_client_id(): @@ -950,10 +953,123 @@ def test_azure_arc_client_id(): EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42", }, ): - credential = ManagedIdentityCredential(client_id="some-guid") + with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None): + credential = ManagedIdentityCredential(client_id="some-guid") - with pytest.raises(ClientAuthenticationError): + with pytest.raises(ClientAuthenticationError) as ex: credential.get_token("scope") + assert "not supported" in str(ex.value) + + +def test_azure_arc_key_too_large(tmp_path): + + api_version = "2019-11-01" + identity_endpoint = "http://localhost:42/token" + imds_endpoint = "http://localhost:42" + scope = "scope" + secret_key = "X" * 4097 + + key_file = tmp_path / "key_file.key" + key_file.write_text(secret_key) + assert key_file.read_text() == secret_key + + transport = validating_transport( + requests=[ + Request( + base_url=identity_endpoint, + method="GET", + required_headers={"Metadata": "true"}, + required_params={"api-version": api_version, "resource": scope}, + ), + ], + responses=[ + mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}), + ], + ) + + with mock.patch( + "os.environ", + {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, + ): + with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)): + with pytest.raises(ClientAuthenticationError) as ex: + ManagedIdentityCredential(transport=transport).get_token(scope) + assert "file size" in str(ex.value) + + +def test_azure_arc_key_not_exist(tmp_path): + + api_version = "2019-11-01" + identity_endpoint = "http://localhost:42/token" + imds_endpoint = "http://localhost:42" + scope = "scope" + + transport = validating_transport( + requests=[ + Request( + base_url=identity_endpoint, + method="GET", + required_headers={"Metadata": "true"}, + required_params={"api-version": api_version, "resource": scope}, + ), + ], + responses=[ + mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=/path/to/key_file"}), + ], + ) + + with mock.patch( + "os.environ", + {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, + ): + with pytest.raises(ClientAuthenticationError) as ex: + ManagedIdentityCredential(transport=transport).get_token(scope) + assert "not exist" in str(ex.value) + + +def test_azure_arc_key_invalid(tmp_path): + + api_version = "2019-11-01" + identity_endpoint = "http://localhost:42/token" + imds_endpoint = "http://localhost:42" + scope = "scope" + key_file = tmp_path / "key_file.txt" + key_file.write_text("secret") + + transport = validating_transport( + requests=[ + Request( + base_url=identity_endpoint, + method="GET", + required_headers={"Metadata": "true"}, + required_params={"api-version": api_version, "resource": scope}, + ), + Request( + base_url=identity_endpoint, + method="GET", + required_headers={"Metadata": "true"}, + required_params={"api-version": api_version, "resource": scope}, + ), + ], + responses=[ + mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}), + mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}), + ], + ) + + with mock.patch( + "os.environ", + {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, + ): + with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"): + with pytest.raises(ClientAuthenticationError) as ex: + ManagedIdentityCredential(transport=transport).get_token(scope) + assert "Unexpected file path" in str(ex.value) + + with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)): + with pytest.raises(ClientAuthenticationError) as ex: + ManagedIdentityCredential(transport=transport).get_token(scope) + assert "extension" in str(ex.value) def test_token_exchange(tmpdir): diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index f9c1981158a2..79d254848744 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -848,9 +848,10 @@ async def test_azure_arc(tmpdir): "os.environ", {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): - token = await ManagedIdentityCredential(transport=transport).get_token(scope) - assert token.token == access_token - assert token.expires_on == expires_on + with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None): + token = await ManagedIdentityCredential(transport=transport).get_token(scope) + assert token.token == access_token + assert token.expires_on == expires_on @pytest.mark.asyncio @@ -901,9 +902,10 @@ async def test_azure_arc_tenant_id(tmpdir): "os.environ", {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): - token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token.token == access_token - assert token.expires_on == expires_on + with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None): + token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") + assert token.token == access_token + assert token.expires_on == expires_on @pytest.mark.asyncio @@ -922,6 +924,115 @@ async def test_azure_arc_client_id(): await credential.get_token("scope") +@pytest.mark.asyncio +async def test_azure_arc_key_too_large(tmp_path): + api_version = "2019-11-01" + identity_endpoint = "http://localhost:42/token" + imds_endpoint = "http://localhost:42" + scope = "scope" + secret_key = "X" * 4097 + + key_file = tmp_path / "key_file.key" + key_file.write_text(secret_key) + assert key_file.read_text() == secret_key + + transport = async_validating_transport( + requests=[ + Request( + base_url=identity_endpoint, + method="GET", + required_headers={"Metadata": "true"}, + required_params={"api-version": api_version, "resource": scope}, + ), + ], + responses=[ + mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}), + ], + ) + with mock.patch( + "os.environ", + {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, + ): + with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)): + with pytest.raises(ClientAuthenticationError) as ex: + await ManagedIdentityCredential(transport=transport).get_token(scope) + assert "file size" in str(ex.value) + + +@pytest.mark.asyncio +async def test_azure_arc_key_not_exist(tmp_path): + api_version = "2019-11-01" + identity_endpoint = "http://localhost:42/token" + imds_endpoint = "http://localhost:42" + scope = "scope" + + transport = async_validating_transport( + requests=[ + Request( + base_url=identity_endpoint, + method="GET", + required_headers={"Metadata": "true"}, + required_params={"api-version": api_version, "resource": scope}, + ), + ], + responses=[ + mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=key_file"}), + ], + ) + with mock.patch( + "os.environ", + {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, + ): + with pytest.raises(ClientAuthenticationError) as ex: + await ManagedIdentityCredential(transport=transport).get_token(scope) + assert "not exist" in str(ex.value) + + +@pytest.mark.asyncio +async def test_azure_arc_key_invalid(tmp_path): + api_version = "2019-11-01" + identity_endpoint = "http://localhost:42/token" + imds_endpoint = "http://localhost:42" + scope = "scope" + key_file = tmp_path / "key_file.txt" + key_file.write_text("secret") + + transport = async_validating_transport( + requests=[ + Request( + base_url=identity_endpoint, + method="GET", + required_headers={"Metadata": "true"}, + required_params={"api-version": api_version, "resource": scope}, + ), + Request( + base_url=identity_endpoint, + method="GET", + required_headers={"Metadata": "true"}, + required_params={"api-version": api_version, "resource": scope}, + ), + ], + responses=[ + mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}), + mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}), + ], + ) + + with mock.patch( + "os.environ", + {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, + ): + with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"): + with pytest.raises(ClientAuthenticationError) as ex: + await ManagedIdentityCredential(transport=transport).get_token(scope) + assert "Unexpected file path" in str(ex.value) + + with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)): + with pytest.raises(ClientAuthenticationError) as ex: + await ManagedIdentityCredential(transport=transport).get_token(scope) + assert "extension" in str(ex.value) + + @pytest.mark.asyncio async def test_token_exchange(tmpdir): exchange_token = "exchange-token"