From 2afbe32b0dbbcdb272a085d95ac5fdc609493d04 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 May 2024 21:06:23 +0900 Subject: [PATCH] Add coverage and improve performance of is_ssh_key (#940) * Add coverage and improve performance of is_ssh_key * simplify --- jwt/utils.py | 17 +++-------------- tests/test_utils.py | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/jwt/utils.py b/jwt/utils.py index 81c5ee41..d469139b 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -131,26 +131,15 @@ def is_pem_format(key: bytes) -> bool: # Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 -_CERT_SUFFIX = b"-cert-v01@openssh.com" -_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") -_SSH_KEY_FORMATS = [ +_SSH_KEY_FORMATS = ( b"ssh-ed25519", b"ssh-rsa", b"ssh-dss", b"ecdsa-sha2-nistp256", b"ecdsa-sha2-nistp384", b"ecdsa-sha2-nistp521", -] +) def is_ssh_key(key: bytes) -> bool: - if any(string_value in key for string_value in _SSH_KEY_FORMATS): - return True - - ssh_pubkey_match = _SSH_PUBKEY_RC.match(key) - if ssh_pubkey_match: - key_type = ssh_pubkey_match.group(1) - if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]: - return True - - return False + return key.startswith(_SSH_KEY_FORMATS) diff --git a/tests/test_utils.py b/tests/test_utils.py index 122dcb4e..d8d0d6c4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,6 @@ import pytest -from jwt.utils import force_bytes, from_base64url_uint, to_base64url_uint +from jwt.utils import force_bytes, from_base64url_uint, is_ssh_key, to_base64url_uint @pytest.mark.parametrize( @@ -37,3 +37,19 @@ def test_from_base64url_uint(inputval, expected): def test_force_bytes_raises_error_on_invalid_object(): with pytest.raises(TypeError): force_bytes({}) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "key_format", + ( + b"ssh-ed25519", + b"ssh-rsa", + b"ssh-dss", + b"ecdsa-sha2-nistp256", + b"ecdsa-sha2-nistp384", + b"ecdsa-sha2-nistp521", + ), +) +def test_is_ssh_key(key_format): + assert is_ssh_key(key_format + b" any") is True + assert is_ssh_key(b"not a ssh key") is False