diff --git a/beanie/__init__.py b/beanie/__init__.py index 91f233de..5a41b445 100644 --- a/beanie/__init__.py +++ b/beanie/__init__.py @@ -16,7 +16,11 @@ from beanie.odm.bulk import BulkWriter from beanie.odm.custom_types import DecimalAnnotation from beanie.odm.custom_types.bson.binary import BsonBinary -from beanie.odm.documents import Document, MergeStrategy +from beanie.odm.documents import ( + Document, + DocumentWithSoftDelete, + MergeStrategy, +) from beanie.odm.enums import SortDirection from beanie.odm.fields import ( BackLink, @@ -37,6 +41,7 @@ __all__ = [ # ODM "Document", + "DocumentWithSoftDelete", "View", "UnionDoc", "init_beanie", diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index 0523b13b..8389880b 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -1,7 +1,9 @@ import asyncio import warnings +from datetime import datetime from enum import Enum from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -12,6 +14,7 @@ List, Mapping, Optional, + Tuple, Type, TypeVar, Union, @@ -52,6 +55,7 @@ ) from beanie.odm.bulk import BulkWriter, Operation from beanie.odm.cache import LRUCache +from beanie.odm.enums import SortDirection from beanie.odm.fields import ( BackLink, DeleteRules, @@ -83,6 +87,7 @@ from beanie.odm.operators.update.general import ( Set as SetOperator, ) +from beanie.odm.queries.find import FindMany, FindOne from beanie.odm.queries.update import UpdateMany, UpdateResponse from beanie.odm.settings.document import DocumentSettings from beanie.odm.utils.dump import get_dict, get_top_level_nones @@ -107,6 +112,10 @@ if IS_PYDANTIC_V2: from pydantic import model_validator +if TYPE_CHECKING: + from beanie.odm.views import View + +FindType = TypeVar("FindType", bound=Union["Document", "View"]) DocType = TypeVar("DocType", bound="Document") P = ParamSpec("P") R = TypeVar("R") @@ -1199,3 +1208,132 @@ async def distinct( def link_from_id(cls, id: Any): ref = DBRef(id=id, collection=cls.get_collection_name()) return Link(ref, document_class=cls) + + +class DocumentWithSoftDelete(Document): + deleted_at: Optional[datetime] = None + + def is_deleted(self) -> bool: + return self.deleted_at is not None + + async def hard_delete( + self, + session: Optional[ClientSession] = None, + bulk_writer: Optional[BulkWriter] = None, + link_rule: DeleteRules = DeleteRules.DO_NOTHING, + skip_actions: Optional[List[Union[ActionDirections, str]]] = None, + **pymongo_kwargs, + ) -> Optional[DeleteResult]: + return await super().delete( + session=session, + bulk_writer=bulk_writer, + link_rule=link_rule, + skip_actions=skip_actions, + **pymongo_kwargs, + ) + + async def delete( + self, + session: Optional[ClientSession] = None, + bulk_writer: Optional[BulkWriter] = None, + link_rule: DeleteRules = DeleteRules.DO_NOTHING, + skip_actions: Optional[List[Union[ActionDirections, str]]] = None, + **pymongo_kwargs, + ) -> Optional[DeleteResult]: + self.deleted_at = datetime.utcnow() + await self.save() + return None + + @classmethod + def find_many_in_all( # type: ignore + cls: Type[FindType], + *args: Union[Mapping[str, Any], bool], + projection_model: None = None, + skip: Optional[int] = None, + limit: Optional[int] = None, + sort: Union[None, str, List[Tuple[str, SortDirection]]] = None, + session: Optional[ClientSession] = None, + ignore_cache: bool = False, + fetch_links: bool = False, + with_children: bool = False, + lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, + **pymongo_kwargs, + ) -> Union[FindMany[FindType], FindMany["DocumentProjectionType"]]: + return cls._find_many_query_class(document_model=cls).find_many( + *args, + sort=sort, + skip=skip, + limit=limit, + projection_model=projection_model, + session=session, + ignore_cache=ignore_cache, + fetch_links=fetch_links, + lazy_parse=lazy_parse, + nesting_depth=nesting_depth, + nesting_depths_per_field=nesting_depths_per_field, + **pymongo_kwargs, + ) + + @classmethod + def find_many( # type: ignore + cls: Type[FindType], + *args: Union[Mapping[str, Any], bool], + projection_model: Optional[Type["DocumentProjectionType"]] = None, + skip: Optional[int] = None, + limit: Optional[int] = None, + sort: Union[None, str, List[Tuple[str, SortDirection]]] = None, + session: Optional[ClientSession] = None, + ignore_cache: bool = False, + fetch_links: bool = False, + with_children: bool = False, + lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, + **pymongo_kwargs, + ) -> Union[FindMany[FindType], FindMany["DocumentProjectionType"]]: + args = cls._add_class_id_filter(args, with_children) + ( + {"deleted_at": None}, + ) + return cls._find_many_query_class(document_model=cls).find_many( + *args, + sort=sort, + skip=skip, + limit=limit, + projection_model=projection_model, + session=session, + ignore_cache=ignore_cache, + fetch_links=fetch_links, + lazy_parse=lazy_parse, + nesting_depth=nesting_depth, + nesting_depths_per_field=nesting_depths_per_field, + **pymongo_kwargs, + ) + + @classmethod + def find_one( # type: ignore + cls: Type[FindType], + *args: Union[Mapping[str, Any], bool], + projection_model: Optional[Type["DocumentProjectionType"]] = None, + session: Optional[ClientSession] = None, + ignore_cache: bool = False, + fetch_links: bool = False, + with_children: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, + **pymongo_kwargs, + ) -> Union[FindOne[FindType], FindOne["DocumentProjectionType"]]: + args = cls._add_class_id_filter(args, with_children) + ( + {"deleted_at": None}, + ) + return cls._find_one_query_class(document_model=cls).find_one( + *args, + projection_model=projection_model, + session=session, + ignore_cache=ignore_cache, + fetch_links=fetch_links, + nesting_depth=nesting_depth, + nesting_depths_per_field=nesting_depths_per_field, + **pymongo_kwargs, + ) diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index 085f45b7..afb582e7 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -27,6 +27,7 @@ DocumentTestModelWithIndexFlagsAliases, DocumentTestModelWithLink, DocumentTestModelWithSimpleIndex, + DocumentTestModelWithSoftDelete, DocumentToBeLinked, DocumentToTestSync, DocumentUnion, @@ -199,6 +200,7 @@ async def init(db): DocumentWithExtras, DocumentWithPydanticConfig, DocumentTestModel, + DocumentTestModelWithSoftDelete, DocumentTestModelWithLink, DocumentTestModelWithCustomCollectionName, DocumentTestModelWithSimpleIndex, @@ -333,6 +335,27 @@ def generate_documents( return generate_documents +@pytest.fixture +def document_soft_delete_not_inserted(): + return DocumentTestModelWithSoftDelete( + test_int=randint(0, 1000000), + test_str="kipasa", + ) + + +@pytest.fixture +def documents_soft_delete_not_inserted(): + docs = [] + for i in range(3): + docs.append( + DocumentTestModelWithSoftDelete( + test_int=randint(0, 1000000), + test_str="kipasa", + ) + ) + return docs + + @pytest.fixture async def document(document_not_inserted) -> DocumentTestModel: return await document_not_inserted.insert() diff --git a/tests/odm/documents/test_soft_delete.py b/tests/odm/documents/test_soft_delete.py new file mode 100644 index 00000000..d3cb4f95 --- /dev/null +++ b/tests/odm/documents/test_soft_delete.py @@ -0,0 +1,110 @@ +from tests.odm.models import DocumentTestModelWithSoftDelete + + +async def test_get_item(document_soft_delete_not_inserted): + # insert a document with soft delete + result = await document_soft_delete_not_inserted.insert() + + # get from db by id + document = await DocumentTestModelWithSoftDelete.get(document_id=result.id) + + assert document.is_deleted() is False + assert document.deleted_at is None + assert document.test_int == result.test_int + assert document.test_str == result.test_str + + # # delete the document + await document.delete() + assert document.is_deleted() is True + + # check if document exist with `.get()` + document = await DocumentTestModelWithSoftDelete.get(document_id=result.id) + assert document is None + + # check document exist in trashed + results = ( + await DocumentTestModelWithSoftDelete.find_many_in_all().to_list() + ) + assert len(results) == 1 + + +async def test_find_one(document_soft_delete_not_inserted): + result = await document_soft_delete_not_inserted.insert() + + # # delete the document + await result.delete() + + # check if document exist with `.find_one()` + document = await DocumentTestModelWithSoftDelete.find_one( + DocumentTestModelWithSoftDelete.id == result.id + ) + assert document is None + + +async def test_find(documents_soft_delete_not_inserted): + # insert 3 documents + inserted_docs = [] + for doc in documents_soft_delete_not_inserted: + result = await doc.insert() + inserted_docs.append(result) + + # use `.find_many()` to get them all + results = await DocumentTestModelWithSoftDelete.find().to_list() + assert len(results) == 3 + + # delete one of them + await inserted_docs[0].delete() + + # check items in with `.find_many()` + results = await DocumentTestModelWithSoftDelete.find_many().to_list() + + assert len(results) == 2 + + founded_documents_id = [doc.id for doc in results] + assert inserted_docs[0].id not in founded_documents_id + + # check in trashed items + results = ( + await DocumentTestModelWithSoftDelete.find_many_in_all().to_list() + ) + assert len(results) == 3 + + +async def test_find_many(documents_soft_delete_not_inserted): + # insert 2 documents + item_1 = await documents_soft_delete_not_inserted[0].insert() + item_2 = await documents_soft_delete_not_inserted[1].insert() + + # use `.find_many()` to get them all + results = await DocumentTestModelWithSoftDelete.find_many().to_list() + assert len(results) == 2 + + # delete one of them + await item_1.delete() + + # check items in with `.find_many()` + results = await DocumentTestModelWithSoftDelete.find_many().to_list() + + assert len(results) == 1 + assert results[0].id == item_2.id + + # check in trashed items + results = ( + await DocumentTestModelWithSoftDelete.find_many_in_all().to_list() + ) + assert len(results) == 2 + + +async def test_hard_delete(document_soft_delete_not_inserted): + result = await document_soft_delete_not_inserted.insert() + await result.hard_delete() + + # check items in with `.find_many()` + results = await DocumentTestModelWithSoftDelete.find_many().to_list() + assert len(results) == 0 + + # check in trashed + results = ( + await DocumentTestModelWithSoftDelete.find_many_in_all().to_list() + ) + assert len(results) == 0 diff --git a/tests/odm/models.py b/tests/odm/models.py index c1664f78..223568af 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -42,6 +42,7 @@ from beanie import ( DecimalAnnotation, Document, + DocumentWithSoftDelete, Indexed, Insert, Replace, @@ -140,6 +141,11 @@ class Sample(Document): const: str = "TEST" +class DocumentTestModelWithSoftDelete(DocumentWithSoftDelete): + test_int: int + test_str: str + + class SubDocument(BaseModel): test_str: str test_int: int = 42