feat(mongodb): implement data filter for mongodb source (#529)
* feat(mongodb): implement data filter for mongodb source
IlyaFaer authored Sep 22, 2024
1 parent b304784 commit 69a1c5e
Showing 3 changed files with 144 additions and 19 deletions.
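For context, the new `filter_` argument accepts a plain PyMongo filter document, which the loaders merge with the filter built from the incremental cursor. A minimal usage sketch, assuming credentials are configured in `.dlt/secrets.toml`; the pipeline and destination names are illustrative, and the `movies` collection matches the tests below:

```python
import dlt

from sources.mongodb import mongodb_collection

# Illustrative pipeline; any configured destination works the same way.
pipeline = dlt.pipeline(
    pipeline_name="mongodb_filter_demo",
    destination="duckdb",
    dataset_name="mongodb_data",
)

# Load only documents in which the `runtime` field is present.
movies = mongodb_collection(
    collection="movies",
    filter_={"runtime": {"$exists": True}},
)

print(pipeline.run(movies))
```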
16 changes: 14 additions & 2 deletions sources/mongodb/__init__.py
@@ -1,6 +1,6 @@
"""Source that loads collections form any a mongo database, supports incremental loads."""

from typing import Any, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional

import dlt
from dlt.common.data_writers import TDataItemFormat
@@ -23,6 +23,7 @@ def mongodb(
write_disposition: Optional[str] = dlt.config.value,
parallel: Optional[bool] = dlt.config.value,
limit: Optional[int] = None,
filter_: Optional[Dict[str, Any]] = None,
) -> Iterable[DltResource]:
"""
A DLT source which loads data from a mongo database using PyMongo.
@@ -39,6 +40,7 @@ limit (Optional[int]):
limit (Optional[int]):
The maximum number of documents to load. The limit is
applied to each requested collection separately.
filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -64,7 +66,14 @@
primary_key="_id",
write_disposition=write_disposition,
spec=MongoDbCollectionConfiguration,
)(client, collection, incremental=incremental, parallel=parallel, limit=limit)
)(
client,
collection,
incremental=incremental,
parallel=parallel,
limit=limit,
filter_=filter_ or {},
)


@dlt.common.configuration.with_config(
@@ -80,6 +89,7 @@ def mongodb_collection(
limit: Optional[int] = None,
chunk_size: Optional[int] = 10000,
data_item_format: Optional[TDataItemFormat] = "object",
filter_: Optional[Dict[str, Any]] = None,
) -> Any:
"""
A DLT source which loads a collection from a mongo database using PyMongo.
@@ -98,6 +108,7 @@
Supported formats:
object - Python objects (dicts, lists).
arrow - Apache Arrow tables.
filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -124,4 +135,5 @@
limit=limit,
chunk_size=chunk_size,
data_item_format=data_item_format,
filter_=filter_ or {},
)
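Within the `mongodb` source, the same `filter_` document is forwarded to every selected collection resource, as the loop in the first hunk shows. A sketch of that source-level usage, assuming the source's existing `collection_names` selector (not shown in this diff) and illustrative collection names:

```python
import dlt

from sources.mongodb import mongodb

# One filter document, applied to each requested collection.
source = mongodb(
    collection_names=["movies", "comments"],
    filter_={"year": {"$gte": 2000}},
)

pipeline = dlt.pipeline(
    pipeline_name="mongodb_source_demo",
    destination="duckdb",
    dataset_name="mongodb_data",
)
pipeline.run(source)
```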
99 changes: 82 additions & 17 deletions sources/mongodb/helpers.py
@@ -100,7 +100,7 @@ def _filter_op(self) -> Dict[str, Any]:

return filt

def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> Cursor: # type: ignore
def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> TCursor: # type: ignore
"""Apply a limit to the cursor, if needed.
Args:
@@ -120,16 +120,23 @@ def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> Cursor: # type

return cursor

def load_documents(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
def load_documents(
self, filter_: Dict[str, Any], limit: Optional[int] = None
) -> Iterator[TDataItem]:
"""Construct the query and load the documents from the collection.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
"""
cursor = self.collection.find(self._filter_op)
filter_op = self._filter_op
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

cursor = self.collection.find(filter=filter_op)
if self._sort_op:
cursor = cursor.sort(self._sort_op)

@@ -157,8 +164,20 @@ def _create_batches(self, limit: Optional[int] = None) -> List[Dict[str, int]]:

return batches

def _get_cursor(self) -> TCursor:
cursor = self.collection.find(filter=self._filter_op)
def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
"""Get a reading cursor for the collection.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
Returns:
Cursor: The cursor for the collection.
"""
filter_op = self._filter_op
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

cursor = self.collection.find(filter=filter_op)
if self._sort_op:
cursor = cursor.sort(self._sort_op)

@@ -174,31 +193,37 @@ def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:

return data

def _get_all_batches(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
def _get_all_batches(
self, filter_: Dict[str, Any], limit: Optional[int] = None
) -> Iterator[TDataItem]:
"""Load all documents from the collection in parallel batches.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The maximum number of documents to load.
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
"""
batches = self._create_batches(limit)
cursor = self._get_cursor()
batches = self._create_batches(limit=limit)
cursor = self._get_cursor(filter_=filter_)

for batch in batches:
yield self._run_batch(cursor=cursor, batch=batch)

def load_documents(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
def load_documents(
self, filter_: Dict[str, Any], limit: Optional[int] = None
) -> Iterator[TDataItem]:
"""Load documents from the collection in parallel.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
"""
for document in self._get_all_batches(limit):
for document in self._get_all_batches(limit=limit, filter_=filter_):
yield document


@@ -208,11 +233,14 @@ class CollectionArrowLoader(CollectionLoader):
Apache Arrow for data processing.
"""

def load_documents(self, limit: Optional[int] = None) -> Iterator[Any]:
def load_documents(
self, filter_: Dict[str, Any], limit: Optional[int] = None
) -> Iterator[Any]:
"""
Load documents from the collection in Apache Arrow format.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
Yields:
@@ -225,9 +253,11 @@ def load_documents(self, limit: Optional[int] = None) -> Iterator[Any]:
None, codec_options=self.collection.codec_options
)

cursor = self.collection.find_raw_batches(
self._filter_op, batch_size=self.chunk_size
)
filter_op = self._filter_op
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

cursor = self.collection.find_raw_batches(filter_op, batch_size=self.chunk_size)
if self._sort_op:
cursor = cursor.sort(self._sort_op) # type: ignore

@@ -246,9 +276,21 @@ class CollectionArrowLoaderParallel(CollectionLoaderParallel):
Apache Arrow for data processing.
"""

def _get_cursor(self) -> TCursor:
def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
"""Get a reading cursor for the collection.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
Returns:
Cursor: The cursor for the collection.
"""
filter_op = self._filter_op
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

cursor = self.collection.find_raw_batches(
filter=self._filter_op, batch_size=self.chunk_size
filter=filter_op, batch_size=self.chunk_size
)
if self._sort_op:
cursor = cursor.sort(self._sort_op) # type: ignore
@@ -276,6 +318,7 @@ def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:
def collection_documents(
client: TMongoClient,
collection: TCollection,
filter_: Dict[str, Any],
incremental: Optional[dlt.sources.incremental[Any]] = None,
parallel: bool = False,
limit: Optional[int] = None,
@@ -289,6 +332,7 @@
Args:
client (MongoClient): The PyMongo client `pymongo.MongoClient` instance.
collection (Collection): The collection `pymongo.collection.Collection` to load.
filter_ (Dict[str, Any]): The filter to apply to the collection.
incremental (Optional[dlt.sources.incremental[Any]]): The incremental configuration.
parallel (bool): Option to enable parallel loading for the collection. Default is False.
limit (Optional[int]): The maximum number of documents to load.
@@ -315,7 +359,7 @@ def collection_documents(
loader = LoaderClass(
client, collection, incremental=incremental, chunk_size=chunk_size
)
for data in loader.load_documents(limit=limit):
for data in loader.load_documents(limit=limit, filter_=filter_):
yield data


@@ -377,6 +421,27 @@ def client_from_credentials(connection_url: str) -> TMongoClient:
return client


def _raise_if_intersection(filter1: Dict[str, Any], filter2: Dict[str, Any]) -> None:
"""
Raise an exception if the given filters
apply the same operator to the same field.
Args:
filter1 (Dict[str, Any]): The first filter.
filter2 (Dict[str, Any]): The second filter.
"""
field_inter = filter1.keys() & filter2.keys()
for field in field_inter:
if filter1[field].keys() & filter2[field].keys():
str_repr = str({field: filter1[field]})
raise ValueError(
(
f"Filtering operator {str_repr} is already used by the "
"incremental and can't be used in the filter."
)
)


@configspec
class MongoDbCollectionConfiguration(BaseConfiguration):
incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg]
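To make the merge semantics concrete: both filters are expected to hold operator documents per field (the guard calls `.keys()` on the per-field values), and `dict.update` replaces the whole per-field value, so when `filter_` shares a field with the incremental, the user's operators win in the Mongo query while dlt's incremental still applies its own cursor-based filtering to the yielded rows. A standalone sketch with illustrative values:

```python
from typing import Any, Dict


def _raise_if_intersection(filter1: Dict[str, Any], filter2: Dict[str, Any]) -> None:
    # Same logic as the helper above: reject a filter_ entry that reuses
    # an operator the incremental already set on the same field.
    for field in filter1.keys() & filter2.keys():
        if filter1[field].keys() & filter2[field].keys():
            str_repr = str({field: filter1[field]})
            raise ValueError(
                f"Filtering operator {str_repr} is already used by the "
                "incremental and can't be used in the filter."
            )


incremental_filter = {"runtime": {"$gte": 500}}  # what _filter_op builds
user_filter = {"runtime": {"$exists": True}}     # what the caller passes as filter_

_raise_if_intersection(incremental_filter, user_filter)  # different operators: passes

merged = dict(incremental_filter)
merged.update(user_filter)
print(merged)  # {'runtime': {'$exists': True}}; the per-field value was replaced

# Reusing the same operator on the same field raises, as test_filter_intersect checks:
# _raise_if_intersection({"runtime": {"$gte": 500}}, {"runtime": {"$gte": 20}})  # ValueError
```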
48 changes: 48 additions & 0 deletions tests/mongodb/test_mongodb_source.py
@@ -6,6 +6,7 @@
from unittest import mock

import dlt
from dlt.pipeline.exceptions import PipelineStepFailed

from sources.mongodb import mongodb, mongodb_collection
from sources.mongodb_pipeline import (
@@ -356,3 +357,50 @@ def test_arrow_types(destination_name):

info = pipeline.run(res, table_name="types_test")
assert info.loads_ids != []


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
def test_filter(destination_name):
"""
The field `runtime` is not set in some movies,
so incremental loading on it alone will not work.
Adding an explicit filter_ that selects only
documents with `runtime` makes it work.
"""
pipeline = dlt.pipeline(
pipeline_name="mongodb_test",
destination=destination_name,
dataset_name="mongodb_test_data",
full_refresh=True,
)
movies = mongodb_collection(
collection="movies",
incremental=dlt.sources.incremental("runtime", initial_value=500),
filter_={"runtime": {"$exists": True}},
)
pipeline.run(movies)

table_counts = load_table_counts(pipeline, "movies")
assert table_counts["movies"] == 23


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
def test_filter_intersect(destination_name):
"""
Check that filter_ cannot reuse fields that
are already used by the incremental.
"""
pipeline = dlt.pipeline(
pipeline_name="mongodb_test",
destination=destination_name,
dataset_name="mongodb_test_data",
full_refresh=True,
)
movies = mongodb_collection(
collection="movies",
incremental=dlt.sources.incremental("runtime", initial_value=20),
filter_={"runtime": {"$gte": 20}},
)

with pytest.raises(PipelineStepFailed):
pipeline.run(movies)

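As the last test shows, the `ValueError` raised by `_raise_if_intersection` surfaces from a pipeline run wrapped in `PipelineStepFailed`. A minimal handling sketch under the same assumptions as the tests (local MongoDB credentials configured; the destination name is illustrative):

```python
import dlt
from dlt.pipeline.exceptions import PipelineStepFailed

from sources.mongodb import mongodb_collection

pipeline = dlt.pipeline(pipeline_name="mongodb_test", destination="duckdb")
movies = mongodb_collection(
    collection="movies",
    incremental=dlt.sources.incremental("runtime", initial_value=20),
    filter_={"runtime": {"$gte": 20}},  # $gte clashes with the incremental's operator
)

try:
    pipeline.run(movies)
except PipelineStepFailed as exc:
    # The underlying ValueError names the clashing filtering operator.
    print(exc)
```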