diff --git a/beanie/odm/queries/find.py b/beanie/odm/queries/find.py
index 5e3e4cd5..0b6360fd 100644
--- a/beanie/odm/queries/find.py
+++ b/beanie/odm/queries/find.py
@@ -940,9 +940,7 @@ async def replace_one(
                 Operation(
                     operation=ReplaceOne,
                     first_query=self.get_filter_query(),
-                    second_query=Encoder(
-                        by_alias=True, exclude={"_id"}
-                    ).encode(document),
+                    second_query=Encoder(exclude={"_id"}).encode(document),
                     object_class=self.document_model,
                     pymongo_kwargs=self.pymongo_kwargs,
                 )
diff --git a/beanie/odm/utils/dump.py b/beanie/odm/utils/dump.py
index af8bc4f3..a030fc9a 100644
--- a/beanie/odm/utils/dump.py
+++ b/beanie/odm/utils/dump.py
@@ -18,9 +18,8 @@ def get_dict(
         exclude.add("_id")
     if not document.get_settings().use_revision:
         exclude.add("revision_id")
-    return Encoder(
-        by_alias=True, exclude=exclude, to_db=to_db, keep_nulls=keep_nulls
-    ).encode(document)
+    encoder = Encoder(exclude=exclude, to_db=to_db, keep_nulls=keep_nulls)
+    return encoder.encode(document)
 
 
 def get_nulls(
diff --git a/beanie/odm/utils/encoder.py b/beanie/odm/utils/encoder.py
index b33bbbd8..4d7a06a5 100644
--- a/beanie/odm/utils/encoder.py
+++ b/beanie/odm/utils/encoder.py
@@ -1,252 +1,192 @@
-import re
-from collections import deque
-from datetime import datetime, timedelta
-from decimal import Decimal
-from enum import Enum
-from ipaddress import (
-    IPv4Address,
-    IPv4Interface,
-    IPv4Network,
-    IPv6Address,
-    IPv6Interface,
-    IPv6Network,
-)
-from pathlib import PurePath
-from types import GeneratorType
-from typing import (
-    AbstractSet,
-    Any,
-    Callable,
-    Dict,
-    List,
-    Mapping,
-    Optional,
-    Type,
-    Union,
-)
-from uuid import UUID
-
-import bson
-from bson import Binary, DBRef, Decimal128, ObjectId, Regex
-from pydantic import BaseModel, SecretBytes, SecretStr
-
-from beanie.odm import documents
-from beanie.odm.fields import Link, LinkTypes
-from beanie.odm.utils.pydantic import IS_PYDANTIC_V2, get_iterator
-
-if IS_PYDANTIC_V2:
-    from pydantic import RootModel
-
-ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
-    timedelta: lambda td: td.total_seconds(),
-    Decimal: Decimal128,
-    deque: list,
-    IPv4Address: str,
-    IPv4Interface: str,
-    IPv4Network: str,
-    IPv6Address: str,
-    IPv6Interface: str,
-    IPv6Network: str,
-    SecretBytes: SecretBytes.get_secret_value,
-    SecretStr: SecretStr.get_secret_value,
-    Enum: lambda o: o.value,
-    PurePath: str,
-    Link: lambda l: l.ref,  # noqa: E741
-    bytes: lambda b: b if isinstance(b, Binary) else Binary(b),
-    UUID: lambda u: bson.Binary.from_uuid(u),
-    re.Pattern: Regex.from_native,
-}
-
-
-class Ignore:
-    ...
-
-
-IGNORE = Ignore()
-
-
-class Encoder:
-    """
-    BSON encoding class
-    """
-
-    def __init__(
-        self,
-        exclude: Union[
-            AbstractSet[Union[str, int]], Mapping[Union[str, int], Any], None
-        ] = None,
-        custom_encoders: Optional[Dict[Type, Callable]] = None,
-        by_alias: bool = True,
-        to_db: bool = False,
-        keep_nulls: bool = True,
-    ):
-        self.exclude = exclude or {}
-        self.by_alias = by_alias
-        self.custom_encoders = custom_encoders or {}
-        self.to_db = to_db
-        self.keep_nulls = keep_nulls
-
-    def encode(self, obj: Any):
-        """
-        Run the encoder
-        """
-        return self._encode(obj=obj)
-
-    def encode_document(self, obj):
-        """
-        Beanie Document class case
-        """
-        obj.parse_store()
-
-        encoder = Encoder(
-            custom_encoders=obj.get_settings().bson_encoders,
-            by_alias=self.by_alias,
-            to_db=self.to_db,
-            keep_nulls=self.keep_nulls,
-        )
-
-        link_fields = obj.get_link_fields()
-        obj_dict: Dict[str, Any] = {}
-        if obj.get_settings().union_doc is not None:
-            obj_dict[obj.get_settings().class_id] = (
-                obj.get_settings().union_doc_alias or obj.__class__.__name__
-            )
-        if obj._inheritance_inited:
-            obj_dict[obj.get_settings().class_id] = obj._class_id
-
-        for k, o in get_iterator(obj, by_alias=self.by_alias):
-            if k not in self.exclude and (
-                self.keep_nulls is True or o is not None
-            ):
-                if link_fields and k in link_fields:
-                    if link_fields[k].link_type == LinkTypes.LIST:
-                        obj_dict[k] = [link.to_ref() for link in o]
-                    if link_fields[k].link_type == LinkTypes.DIRECT:
-                        obj_dict[k] = o.to_ref()
-                    if link_fields[k].link_type == LinkTypes.OPTIONAL_DIRECT:
-                        if o is not None:
-                            obj_dict[k] = o.to_ref()
-                        else:
-                            obj_dict[k] = o
-                    if link_fields[k].link_type == LinkTypes.OPTIONAL_LIST:
-                        if o is not None:
-                            obj_dict[k] = [link.to_ref() for link in o]
-                        else:
-                            obj_dict[k] = o
-                    if (
-                        link_fields[k].link_type == LinkTypes.BACK_DIRECT
-                        and self.to_db
-                    ):
-                        obj_dict[k] = IGNORE
-                    if (
-                        link_fields[k].link_type == LinkTypes.BACK_LIST
-                        and self.to_db
-                    ):
-                        obj_dict[k] = IGNORE
-                    if (
-                        link_fields[k].link_type
-                        == LinkTypes.OPTIONAL_BACK_DIRECT
-                        and self.to_db
-                    ):
-                        obj_dict[k] = IGNORE
-                    if (
-                        link_fields[k].link_type
-                        == LinkTypes.OPTIONAL_BACK_LIST
-                        and self.to_db
-                    ):
-                        obj_dict[k] = IGNORE
-                else:
-                    obj_dict[k] = o
-
-                if isinstance(obj_dict[k], Ignore) and obj_dict[k] == IGNORE:
-                    # Check the class, as direct comparison might not work, like with numpy arrays
-                    del obj_dict[k]
-                else:
-                    obj_dict[k] = encoder.encode(obj_dict[k])
-        return obj_dict
-
-    def encode_base_model(self, obj):
-        """
-        BaseModel case
-        """
-        obj_dict = {}
-        for k, o in get_iterator(obj, by_alias=self.by_alias):
-            if k not in self.exclude and (
-                self.keep_nulls is True or o is not None
-            ):
-                obj_dict[k] = self._encode(o)
-
-        return obj_dict
-
-    def encode_root_model(self, obj):
-        """
-        RootModel case
-        """
-        return self._encode(obj.root)
-
-    def encode_dict(self, obj):
-        """
-        Dictionary case
-        """
-        return {key: self._encode(value) for key, value in obj.items()}
-
-    def encode_iterable(self, obj):
-        """
-        Iterable case
-        """
-        return [self._encode(item) for item in obj]
-
-    def _encode(
-        self,
-        obj,
-    ) -> Any:
-        """"""
-        if self.custom_encoders:
-            if type(obj) in self.custom_encoders:
-                return self.custom_encoders[type(obj)](obj)
-            for encoder_type, encoder in self.custom_encoders.items():
-                if isinstance(obj, encoder_type):
-                    return encoder(obj)
-        if type(obj) in ENCODERS_BY_TYPE:
-            return ENCODERS_BY_TYPE[type(obj)](obj)
-        for cls, encoder in ENCODERS_BY_TYPE.items():
-            if isinstance(obj, cls):
-                return encoder(obj)
-
-        if isinstance(obj, documents.Document):
-            return self.encode_document(obj)
-        if IS_PYDANTIC_V2 and isinstance(obj, RootModel):
-            return self.encode_root_model(obj)
-        if isinstance(obj, BaseModel):
-            return self.encode_base_model(obj)
-        if isinstance(obj, dict):
-            return self.encode_dict(obj)
-        if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
-            return self.encode_iterable(obj)
-
-        if isinstance(
-            obj,
-            (
-                str,
-                int,
-                float,
-                ObjectId,
-                datetime,
-                type(None),
-                DBRef,
-                Decimal128,
-            ),
-        ):
-            return obj
-
-        errors: List[Exception] = []
-        try:
-            data = dict(obj)
-        except Exception as e:
-            errors.append(e)
-            try:
-                data = vars(obj)
-            except Exception as e:
-                errors.append(e)
-                raise ValueError(errors)
-        return self._encode(data)
+import dataclasses as dc
+import datetime
+import decimal
+import enum
+import ipaddress
+import operator
+import pathlib
+import re
+import uuid
+from typing import Any, Callable, Container, Iterable, Mapping, Optional, Tuple
+
+import bson
+import pydantic
+
+import beanie
+from beanie.odm.fields import Link, LinkTypes
+from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
+
+# Built-in converters keyed by type. Encoder.encode consults them after the
+# caller's custom encoders and after the BSON scalar passthrough.
+DEFAULT_CUSTOM_ENCODERS: Mapping[type, Callable] = {
+    ipaddress.IPv4Address: str,
+    ipaddress.IPv4Interface: str,
+    ipaddress.IPv4Network: str,
+    ipaddress.IPv6Address: str,
+    ipaddress.IPv6Interface: str,
+    ipaddress.IPv6Network: str,
+    pathlib.PurePath: str,
+    pydantic.SecretBytes: pydantic.SecretBytes.get_secret_value,
+    pydantic.SecretStr: pydantic.SecretStr.get_secret_value,
+    datetime.timedelta: operator.methodcaller("total_seconds"),
+    enum.Enum: operator.attrgetter("value"),
+    Link: operator.attrgetter("ref"),
+    bytes: bson.Binary,
+    decimal.Decimal: bson.Decimal128,
+    uuid.UUID: bson.Binary.from_uuid,
+    re.Pattern: bson.Regex.from_native,
+}
+# Values of these types are returned by Encoder.encode unchanged.
+BSON_SCALAR_TYPES = (
+    type(None),
+    str,
+    int,
+    float,
+    datetime.datetime,
+    bson.Binary,
+    bson.DBRef,
+    bson.Decimal128,
+    bson.ObjectId,
+)
+
+
+@dc.dataclass
+class Encoder:
+    """
+    BSON encoding class
+
+    exclude: model keys (after alias substitution) to leave out
+    custom_encoders: type -> callable overrides, tried before anything else
+    to_db: when true, Document link fields that are neither direct nor list
+        links are omitted
+    keep_nulls: when false, model fields whose value is None are omitted
+    """
+
+    exclude: Container[str] = frozenset()
+    custom_encoders: Mapping[type, Callable] = dc.field(default_factory=dict)
+    to_db: bool = False
+    keep_nulls: bool = True
+
+    def _encode_document(self, obj: "beanie.Document") -> Mapping[str, Any]:
+        """
+        Encode a Document: add the class id where set, convert direct and
+        list links to references, and encode each remaining field value.
+        """
+        obj.parse_store()
+        settings = obj.get_settings()
+        obj_dict = {}
+        if settings.union_doc is not None:
+            obj_dict[settings.class_id] = (
+                settings.union_doc_alias or obj.__class__.__name__
+            )
+        if obj._class_id:
+            obj_dict[settings.class_id] = obj._class_id
+
+        link_fields = obj.get_link_fields() or {}
+        sub_encoder = Encoder(
+            # don't propagate self.exclude to subdocuments
+            custom_encoders=settings.bson_encoders,
+            to_db=self.to_db,
+            keep_nulls=self.keep_nulls,
+        )
+        for key, value in self._iter_model_items(obj):
+            if key in link_fields:
+                link_type = link_fields[key].link_type
+                if link_type in (LinkTypes.DIRECT, LinkTypes.OPTIONAL_DIRECT):
+                    if value is not None:
+                        value = value.to_ref()
+                elif link_type in (LinkTypes.LIST, LinkTypes.OPTIONAL_LIST):
+                    if value is not None:
+                        value = [link.to_ref() for link in value]
+                elif self.to_db:
+                    continue
+            obj_dict[key] = sub_encoder.encode(value)
+        return obj_dict
+
+    def encode(self, obj: Any) -> Any:
+        """
+        Encode obj, trying in order: custom encoders, BSON scalar
+        passthrough, default encoders, Document, RootModel (Pydantic v2),
+        BaseModel, Mapping, Iterable, then dict(obj) and vars(obj).
+
+        Raises ValueError when neither dict(obj) nor vars(obj) works.
+        """
+        if self.custom_encoders:
+            encoder = _get_encoder(obj, self.custom_encoders)
+            if encoder is not None:
+                return encoder(obj)
+
+        if isinstance(obj, BSON_SCALAR_TYPES):
+            return obj
+
+        encoder = _get_encoder(obj, DEFAULT_CUSTOM_ENCODERS)
+        if encoder is not None:
+            return encoder(obj)
+
+        if isinstance(obj, beanie.Document):
+            return self._encode_document(obj)
+        if IS_PYDANTIC_V2 and isinstance(obj, pydantic.RootModel):
+            return self.encode(obj.root)
+        if isinstance(obj, pydantic.BaseModel):
+            items = self._iter_model_items(obj)
+            return {key: self.encode(value) for key, value in items}
+        if isinstance(obj, Mapping):
+            return {key: self.encode(value) for key, value in obj.items()}
+        if isinstance(obj, Iterable):
+            return [self.encode(value) for value in obj]
+
+        errors = []
+        try:
+            data = dict(obj)
+        except Exception as e:
+            errors.append(e)
+            try:
+                data = vars(obj)
+            except Exception as e:
+                errors.append(e)
+                raise ValueError(errors)
+        return self.encode(data)
+
+    if IS_PYDANTIC_V2:
+
+        def _iter_model_items(
+            self, obj: pydantic.BaseModel
+        ) -> Iterable[Tuple[str, Any]]:
+            """Yield aliased (key, value) pairs that pass the filters."""
+            exclude, keep_nulls = self.exclude, self.keep_nulls
+            # Look up model_fields on the class, once per model: instance
+            # access is deprecated since Pydantic 2.11.
+            model_fields = type(obj).model_fields
+            for key, value in obj.__iter__():
+                field_info = model_fields.get(key)
+                if field_info is not None:
+                    key = field_info.alias or key
+                if key not in exclude and (value is not None or keep_nulls):
+                    yield key, value
+
+    else:
+
+        def _iter_model_items(
+            self, obj: pydantic.BaseModel
+        ) -> Iterable[Tuple[str, Any]]:
+            """Yield aliased (key, value) pairs that pass the filters."""
+            exclude, keep_nulls = self.exclude, self.keep_nulls
+            for key, value in obj._iter(to_dict=False, by_alias=True):
+                if key not in exclude and (value is not None or keep_nulls):
+                    yield key, value
+
+
+def _get_encoder(
+    obj: Any, custom_encoders: Mapping[type, Callable]
+) -> Optional[Callable]:
+    """
+    Return the encoder registered for obj's exact type, else the first one
+    whose key obj is an instance of, else None.
+    """
+    encoder = custom_encoders.get(type(obj))
+    if encoder is not None:
+        return encoder
+    for cls, encoder in custom_encoders.items():
+        if isinstance(obj, cls):
+            return encoder
+    return None
diff --git a/beanie/odm/utils/pydantic.py b/beanie/odm/utils/pydantic.py
index c05f5535..f2446eb3 100644
--- a/beanie/odm/utils/pydantic.py
+++ b/beanie/odm/utils/pydantic.py
@@ -60,25 +60,3 @@ def get_model_dump(model):
         return model.model_dump()
     else:
         return model.dict()
-
-
-def get_iterator(model, by_alias=False):
-    if IS_PYDANTIC_V2:
-
-        def _get_alias(model, k):
-            v = model.model_fields.get(k)
-            if v is not None:
-                return v.alias or k
-            else:
-                return k
-
-        def _iter(model, by_alias=False):
-            for k, v in model.__iter__():
-                if by_alias:
-                    yield _get_alias(model, k), v
-                else:
-                    yield k, v
-
-        return _iter(model, by_alias=by_alias)
-    else:
-        return model._iter(to_dict=False, by_alias=by_alias)