Skip to content

Commit

Permalink
✨ Add Base64 and a general Encoded types (pydantic#5615)
Browse files Browse the repository at this point in the history
  • Loading branch information
lig committed May 4, 2023
1 parent 97e6138 commit 643e332
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 3 deletions.
1 change: 1 addition & 0 deletions changes/692-lig.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `EncodedBytes` and `EncodedStr` general types and implement `base64` encoding support in `Base64Bytes` and `Base64Str` types.
130 changes: 130 additions & 0 deletions docs/usage/types/encoded.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
---
description: Support for storing data in an encoded form.
---

`EncodedBytes`
: a bytes value which is decoded from/encoded into a (different) bytes value during validation/serialization

`EncodedStr`
: a string value which is decoded from/encoded into a (different) string value during validation/serialization

`EncodedBytes` and `EncodedStr` needs an encoder that implements `EncoderProtocol` to operate.

```py
from typing import Optional

from typing_extensions import Annotated

from pydantic import (
BaseModel,
EncodedBytes,
EncodedStr,
EncoderProtocol,
ValidationError,
)


class MyEncoder(EncoderProtocol):
@classmethod
def decode(cls, data: bytes) -> bytes:
if data == b'**undecodable**':
raise ValueError('Cannot decode data')
return data[13:]

@classmethod
def encode(cls, value: bytes) -> bytes:
return b'**encoded**: ' + value

@classmethod
def get_json_format(cls) -> str:
return 'my-encoder'


MyEncodedBytes = Annotated[bytes, EncodedBytes(encoder=MyEncoder)]
MyEncodedStr = Annotated[str, EncodedStr(encoder=MyEncoder)]


class Model(BaseModel):
my_encoded_bytes: MyEncodedBytes
my_encoded_str: Optional[MyEncodedStr] = None


# Initialize the model with encoded data
m = Model(
my_encoded_bytes=b'**encoded**: some bytes', my_encoded_str='**encoded**: some str'
)

# Access decoded value
print(m.my_encoded_bytes)
#> b'some bytes'
print(m.my_encoded_str)
#> some str

# Serialize into the encoded form
print(m.model_dump())
"""
{
'my_encoded_bytes': b'**encoded**: some bytes',
'my_encoded_str': '**encoded**: some str',
}
"""

# Validate encoded data
try:
Model(my_encoded_bytes=b'**undecodable**')
except ValidationError as e:
print(e)
"""
1 validation error for Model
my_encoded_bytes
Value error, Cannot decode data [type=value_error, input_value=b'**undecodable**', input_type=bytes]
"""
```

## Base64 encoding support

Internally, pydantic uses the `pydantic.types.EncodedBytes` and `pydantic.types.EncodedStr` annotations with `pydantic.types.Base64Encoder` to implement base64 encoding/decoding in the `Base64Bytes` and `Base64Str` types, respectively.

```py
from typing import Optional

from pydantic import Base64Bytes, Base64Str, BaseModel, ValidationError


class Model(BaseModel):
base64_bytes: Base64Bytes
base64_str: Optional[Base64Str] = None


# Initialize the model with base64 data
m = Model(
base64_bytes=b'VGhpcyBpcyB0aGUgd2F5',
base64_str='VGhlc2UgYXJlbid0IHRoZSBkcm9pZHMgeW91J3JlIGxvb2tpbmcgZm9y',
)

# Access decoded value
print(m.base64_bytes)
#> b'This is the way'
print(m.base64_str)
#> These aren't the droids you're looking for

# Serialize into the base64 form
print(m.model_dump())
"""
{
'base64_bytes': b'VGhpcyBpcyB0aGUgd2F5\n',
'base64_str': 'VGhlc2UgYXJlbid0IHRoZSBkcm9pZHMgeW91J3JlIGxvb2tpbmcgZm9y\n',
}
"""

# Validate base64 data
try:
print(Model(base64_bytes=b'undecodable').base64_bytes)
except ValidationError as e:
print(e)
"""
1 validation error for Model
base64_bytes
Base64 decoding error: 'Incorrect padding' [type=base64_decode, input_value=b'undecodable', input_type=bytes]
"""
```
1 change: 1 addition & 0 deletions docs/usage/types/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ The following sections describe the types supported by Pydantic.
* [Unions](unions.md) — allows a model attribute to accept different types.
* [URLs](urls.md) — URI/URL validation types.
* [UUIDs](uuids.md) — types that allow you to store UUIDs in your model.
* [Base64 and other encodings](encoded.md) — types that allow serializing values into an encoded form, e.g. `base64`.
* [Custom Data Types](custom.md) — create your own custom data types.
* [Field Type Conversions](../conversion_table.md) — strict and lax conversion between different field types.
6 changes: 6 additions & 0 deletions pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@
'AwareDatetime',
'NaiveDatetime',
'AllowInfNan',
'EncoderProtocol',
'EncodedBytes',
'EncodedStr',
'Base64Encoder',
'Base64Bytes',
'Base64Str',
# version
'VERSION',
]
Expand Down
102 changes: 99 additions & 3 deletions pydantic/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations as _annotations

import abc
import base64
import dataclasses as _dataclasses
import re
from datetime import date, datetime
Expand All @@ -10,6 +11,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
FrozenSet,
Generic,
Expand All @@ -23,9 +25,9 @@

import annotated_types
from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema
from typing_extensions import Annotated, Literal
from typing_extensions import Annotated, Literal, Protocol

from ._internal import _fields, _validators
from ._internal import _fields, _internal_dataclass, _validators
from ._migration import getattr_migration
from .annotated import GetCoreSchemaHandler
from .errors import PydanticUserError
Expand Down Expand Up @@ -74,6 +76,12 @@
'AwareDatetime',
'NaiveDatetime',
'AllowInfNan',
'EncoderProtocol',
'EncodedBytes',
'EncodedStr',
'Base64Encoder',
'Base64Bytes',
'Base64Str',
]

from ._internal._core_metadata import build_metadata_dict
Expand Down Expand Up @@ -318,7 +326,7 @@ def condecimal(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ UUID TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@_dataclasses.dataclass(frozen=True) # Add frozen=True to make it hashable
@_internal_dataclass.slots_dataclass
class UuidVersion:
uuid_version: Literal[1, 3, 4, 5]

Expand All @@ -342,6 +350,9 @@ def validate(self, value: UUID, _: core_schema.ValidationInfo) -> UUID:
)
return value

def __hash__(self) -> int:
return hash(type(self.uuid_version))


UUID1 = Annotated[UUID, UuidVersion(1)]
UUID3 = Annotated[UUID, UuidVersion(3)]
Expand Down Expand Up @@ -882,4 +893,89 @@ def __repr__(self) -> str:
return 'NaiveDatetime'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Encoded TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


class EncoderProtocol(Protocol):
@classmethod
def decode(cls, data: bytes) -> bytes:
"""Can throw `PydanticCustomError`"""
...

@classmethod
def encode(cls, value: bytes) -> bytes:
...

@classmethod
def get_json_format(cls) -> str:
...


class Base64Encoder(EncoderProtocol):
@classmethod
def decode(cls, data: bytes) -> bytes:
try:
return base64.decodebytes(data)
except ValueError as e:
raise PydanticCustomError('base64_decode', "Base64 decoding error: '{error}'", {'error': str(e)})

@classmethod
def encode(cls, value: bytes) -> bytes:
return base64.encodebytes(value)

@classmethod
def get_json_format(cls) -> str:
return 'base64'


@_internal_dataclass.slots_dataclass
class EncodedBytes:
encoder: type[EncoderProtocol]

def __get_pydantic_json_schema__(
self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
field_schema = handler(core_schema)
field_schema.update(type='string', format=self.encoder.get_json_format())
return field_schema

def __get_pydantic_core_schema__(
self, source: type[Any], handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
return core_schema.general_after_validator_function(
function=self.decode,
schema=core_schema.bytes_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(function=self.encode),
)

def decode(self, data: bytes, _: core_schema.ValidationInfo) -> bytes:
return self.encoder.decode(data)

def encode(self, value: bytes) -> bytes:
return self.encoder.encode(value)

def __hash__(self) -> int:
return hash(self.encoder)


class EncodedStr(EncodedBytes):
def __get_pydantic_core_schema__(
self, source: type[Any], handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
return core_schema.general_after_validator_function(
function=self.decode_str,
schema=super().__get_pydantic_core_schema__(source=source, handler=handler),
serialization=core_schema.plain_serializer_function_ser_schema(function=self.encode_str),
)

def decode_str(self, data: bytes, _: core_schema.ValidationInfo) -> str:
return data.decode()

def encode_str(self, value: str) -> str:
return super().encode(value=value.encode()).decode()


Base64Bytes = Annotated[bytes, EncodedBytes(encoder=Base64Encoder)]
Base64Str = Annotated[str, EncodedStr(encoder=Base64Encoder)]

__getattr__ = getattr_migration(__name__)
81 changes: 81 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
UUID5,
AnalyzedType,
AwareDatetime,
Base64Bytes,
Base64Str,
BaseModel,
ByteSize,
ConfigDict,
Expand Down Expand Up @@ -4351,3 +4353,82 @@ class Model(BaseModel):
'type': 'int_parsing',
}
]


@pytest.mark.parametrize(
('field_type', 'input_data', 'expected_value', 'serialized_data'),
[
pytest.param(Base64Bytes, b'Zm9vIGJhcg==\n', b'foo bar', b'Zm9vIGJhcg==\n', id='Base64Bytes-reversible'),
pytest.param(Base64Str, 'Zm9vIGJhcg==\n', 'foo bar', 'Zm9vIGJhcg==\n', id='Base64Str-reversible'),
pytest.param(Base64Bytes, b'Zm9vIGJhcg==', b'foo bar', b'Zm9vIGJhcg==\n', id='Base64Bytes-bytes-input'),
pytest.param(Base64Bytes, 'Zm9vIGJhcg==', b'foo bar', b'Zm9vIGJhcg==\n', id='Base64Bytes-str-input'),
pytest.param(
Base64Bytes, bytearray(b'Zm9vIGJhcg=='), b'foo bar', b'Zm9vIGJhcg==\n', id='Base64Bytes-bytearray-input'
),
pytest.param(Base64Str, b'Zm9vIGJhcg==', 'foo bar', 'Zm9vIGJhcg==\n', id='Base64Str-bytes-input'),
pytest.param(Base64Str, 'Zm9vIGJhcg==', 'foo bar', 'Zm9vIGJhcg==\n', id='Base64Str-str-input'),
pytest.param(
Base64Str, bytearray(b'Zm9vIGJhcg=='), 'foo bar', 'Zm9vIGJhcg==\n', id='Base64Str-bytearray-input'
),
],
)
def test_base64(field_type, input_data, expected_value, serialized_data):
class Model(BaseModel):
base64_value: field_type
base64_value_or_none: Optional[field_type] = None

m = Model(base64_value=input_data)
assert m.base64_value == expected_value

m = Model.model_construct(base64_value=expected_value)
assert m.base64_value == expected_value

assert m.model_dump() == {
'base64_value': serialized_data,
'base64_value_or_none': None,
}

assert Model.model_json_schema() == {
'properties': {
'base64_value': {
'format': 'base64',
'title': 'Base64 Value',
'type': 'string',
},
'base64_value_or_none': {
'anyOf': [{'type': 'string', 'format': 'base64'}, {'type': 'null'}],
'default': None,
'title': 'Base64 Value Or None',
},
},
'required': ['base64_value'],
'title': 'Model',
'type': 'object',
}


@pytest.mark.parametrize(
('field_type', 'input_data'),
[
pytest.param(Base64Bytes, b'Zm9vIGJhcg', id='Base64Bytes-invalid-base64-bytes'),
pytest.param(Base64Bytes, 'Zm9vIGJhcg', id='Base64Bytes-invalid-base64-str'),
pytest.param(Base64Str, b'Zm9vIGJhcg', id='Base64Str-invalid-base64-bytes'),
pytest.param(Base64Str, 'Zm9vIGJhcg', id='Base64Str-invalid-base64-str'),
],
)
def test_base64_invalid(field_type, input_data):
class Model(BaseModel):
base64_value: field_type

with pytest.raises(ValidationError) as e:
Model(base64_value=input_data)

assert e.value.errors() == [
{
'ctx': {'error': 'Incorrect padding'},
'input': input_data,
'loc': ('base64_value',),
'msg': "Base64 decoding error: 'Incorrect padding'",
'type': 'base64_decode',
},
]

0 comments on commit 643e332

Please sign in to comment.