Skip to content

Commit

Permalink
Fix Chacha20Cipher with companion
Browse files Browse the repository at this point in the history
followup to #2338

add tests
  • Loading branch information
bdraco authored and postlund committed Jun 24, 2024
1 parent fb78faf commit ea9d333
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 19 deletions.
6 changes: 3 additions & 3 deletions pyatv/auth/hap_srp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def verify1(self, credentials, session_pub_key, encrypted):
"Pair-Verify-Encrypt-Salt", "Pair-Verify-Encrypt-Info", self._shared
)

chacha = chacha20.Chacha20Cipher(session_key, session_key)
chacha = chacha20.Chacha20Cipher8byteNonce(session_key, session_key)
decrypted_tlv = read_tlv(chacha.decrypt(encrypted, nonce="PV-Msg02".encode()))

identifier = decrypted_tlv[TlvValue.Identifier]
Expand Down Expand Up @@ -199,14 +199,14 @@ def step3(
if additional_data:
tlv.update(additional_data)

chacha = chacha20.Chacha20Cipher(self._session_key, self._session_key)
chacha = chacha20.Chacha20Cipher8byteNonce(self._session_key, self._session_key)
encrypted_data = chacha.encrypt(write_tlv(tlv), nonce="PS-Msg05".encode())
log_binary(_LOGGER, "Data", Encrypted=encrypted_data)
return encrypted_data

def step4(self, encrypted_data):
"""Last pairing step."""
chacha = chacha20.Chacha20Cipher(self._session_key, self._session_key)
chacha = chacha20.Chacha20Cipher8byteNonce(self._session_key, self._session_key)
decrypted_tlv_bytes = chacha.decrypt(encrypted_data, nonce="PS-Msg06".encode())

if not decrypted_tlv_bytes:
Expand Down
6 changes: 3 additions & 3 deletions pyatv/protocols/airplay/server_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def _m1_verify(self, pairing_data):
{TlvValue.Identifier: self.unique_id, TlvValue.Signature: signature}
)

chacha = chacha20.Chacha20Cipher(session_key, session_key)
chacha = chacha20.Chacha20Cipher8byteNonce(session_key, session_key)
encrypted = chacha.encrypt(tlv, nonce="PV-Msg02".encode())

tlv = {
Expand Down Expand Up @@ -368,7 +368,7 @@ def _m5_setup(self, pairing_data, transient: bool):
binascii.unhexlify(self.session.key),
)

chacha = chacha20.Chacha20Cipher(session_key, session_key)
chacha = chacha20.Chacha20Cipher8byteNonce(session_key, session_key)
decrypted_tlv_bytes = chacha.decrypt(
pairing_data[TlvValue.EncryptedData], nonce="PS-Msg05".encode()
)
Expand Down Expand Up @@ -396,7 +396,7 @@ def _m5_setup(self, pairing_data, transient: bool):
}
)

chacha = chacha20.Chacha20Cipher(session_key, session_key)
chacha = chacha20.Chacha20Cipher8byteNonce(session_key, session_key)
encrypted = chacha.encrypt(tlv, nonce="PS-Msg06".encode())

self.has_paired()
Expand Down
2 changes: 1 addition & 1 deletion pyatv/protocols/mrp/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def eof_received(self):

def enable_encryption(self, output_key: bytes, input_key: bytes) -> None:
"""Enable encryption with the specified keys."""
self._chacha = chacha20.Chacha20Cipher(output_key, input_key)
self._chacha = chacha20.Chacha20Cipher8byteNonce(output_key, input_key)

@property
def connected(self) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions pyatv/protocols/mrp/server_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _m1_verify(self, pairing_data):
{TlvValue.Identifier: self.unique_id, TlvValue.Signature: signature}
)

chacha = chacha20.Chacha20Cipher(session_key, session_key)
chacha = chacha20.Chacha20Cipher8byteNonce(session_key, session_key)
encrypted = chacha.encrypt(tlv, nonce="PV-Msg02".encode())

msg = messages.crypto_pairing(
Expand Down Expand Up @@ -227,7 +227,7 @@ def _m5_setup(self, _):
}
)

chacha = chacha20.Chacha20Cipher(session_key, session_key)
chacha = chacha20.Chacha20Cipher8byteNonce(session_key, session_key)
encrypted = chacha.encrypt(tlv, nonce="PS-Msg06".encode())

msg = messages.crypto_pairing(
Expand Down
59 changes: 50 additions & 9 deletions pyatv/support/chacha20.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,11 @@

NONCE_LENGTH = 12

# The first 4 bytes are always 0, followed by 8 bytes of counter
# for a total of 12 bytes.
PACK_NONCE = partial(Struct("<LQ").pack, 0)


class Chacha20Cipher:
"""CHACHA20 encryption/decryption layer."""

def __init__(self, out_key, in_key, nonce_length=8):
def __init__(self, out_key: bytes, in_key: bytes, nonce_length: int = 8) -> None:
"""Initialize a new Chacha20Cipher."""
self._enc_out = ChaCha20Poly1305(out_key)
self._enc_in = ChaCha20Poly1305(in_key)
Expand All @@ -31,7 +27,11 @@ def out_nonce(self) -> bytes:
This is the nonce that will be used by encrypt in the _next_ call if no custom
nonce is specified.
"""
return PACK_NONCE(self._out_counter)
nonce_length = self._nonce_length
nonce = self._out_counter.to_bytes(length=nonce_length, byteorder="little")
if nonce_length != NONCE_LENGTH:
return self._pad_nonce(nonce)
return nonce

@property
def in_nonce(self) -> bytes:
Expand All @@ -40,7 +40,15 @@ def in_nonce(self) -> bytes:
This is the nonce that will be used by decrypt in the _next_ call if no custom
nonce is specified.
"""
return PACK_NONCE(self._in_counter)
nonce_length = self._nonce_length
nonce = self._in_counter.to_bytes(length=nonce_length, byteorder="little")
if nonce_length != NONCE_LENGTH:
return self._pad_nonce(nonce)
return nonce

def _pad_nonce(self, nonce: bytes) -> bytes:
"""Pad nonce to 12 bytes."""
return b"\x00" * (NONCE_LENGTH - len(nonce)) + nonce

def encrypt(
self, data: bytes, nonce: Optional[bytes] = None, aad: Optional[bytes] = None
Expand All @@ -50,7 +58,7 @@ def encrypt(
nonce = self.out_nonce
self._out_counter += 1
elif len(nonce) < NONCE_LENGTH:
nonce = b"\x00" * (NONCE_LENGTH - len(nonce)) + nonce
nonce = self._pad_nonce(nonce)
return self._enc_out.encrypt(nonce, data, aad)

def decrypt(
Expand All @@ -61,5 +69,38 @@ def decrypt(
nonce = self.in_nonce
self._in_counter += 1
elif len(nonce) < NONCE_LENGTH:
nonce = b"\x00" * (NONCE_LENGTH - len(nonce)) + nonce
nonce = self._pad_nonce(nonce)
return self._enc_in.decrypt(nonce, data, aad)


_PACK_NONCE_WITH_4_BYTE_PAD = partial(Struct("<LQ").pack, 0)


class Chacha20Cipher8byteNonce(Chacha20Cipher):
"""CHACHA20 encryption/decryption layer with an 8 byte counter.
The first 4 bytes are always 0, followed by 8 bytes of counter
for a total of 12 bytes.
"""

def __init__(self, out_key: bytes, in_key: bytes) -> None:
"""Initialize a new Chacha20Cipher8byteNonce."""
super().__init__(out_key, in_key, nonce_length=8)

@property
def out_nonce(self) -> bytes:
"""Return next encrypt nonce.
This is the nonce that will be used by encrypt in the _next_ call if no custom
nonce is specified.
"""
return _PACK_NONCE_WITH_4_BYTE_PAD(self._out_counter)

@property
def in_nonce(self) -> bytes:
"""Return next decrypt nonce.
This is the nonce that will be used by decrypt in the _next_ call if no custom
nonce is specified.
"""
return _PACK_NONCE_WITH_4_BYTE_PAD(self._in_counter)
2 changes: 1 addition & 1 deletion tests/fake_device/mrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def connection_lost(self, exc):

def enable_encryption(self, output_key: bytes, input_key: bytes) -> None:
"""Enable encryption with specified keys."""
self.chacha = chacha20.Chacha20Cipher(output_key, input_key)
self.chacha = chacha20.Chacha20Cipher8byteNonce(output_key, input_key)
self.state.has_authenticated = True

def send_to_client(self, message: ProtobufMessage) -> None:
Expand Down
23 changes: 23 additions & 0 deletions tests/support/test_chacha20.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Unit tests for pyatv.support.chacha20."""

import logging

from pyatv.support import chacha20

fake_key = b"k" * 32


def test_12_bytes_nonce():
cipher = chacha20.Chacha20Cipher(fake_key, fake_key, 12)
assert len(cipher.out_nonce) == chacha20.NONCE_LENGTH
assert len(cipher.in_nonce) == chacha20.NONCE_LENGTH
result = cipher.encrypt(b"test")
assert cipher.decrypt(result) == b"test"


def test_8_bytes_nonce():
cipher = chacha20.Chacha20Cipher8byteNonce(fake_key, fake_key)
assert len(cipher.out_nonce) == chacha20.NONCE_LENGTH
assert len(cipher.in_nonce) == chacha20.NONCE_LENGTH
result = cipher.encrypt(b"test")
assert cipher.decrypt(result) == b"test"

0 comments on commit ea9d333

Please sign in to comment.