diff --git a/sources/mongodb/helpers.py b/sources/mongodb/helpers.py index 7a53f337c..fe5dcc69c 100644 --- a/sources/mongodb/helpers.py +++ b/sources/mongodb/helpers.py @@ -29,6 +29,13 @@ TCollection = Any TCursor = Any +try: + import pymongoarrow # type: ignore + + PYMONGOARROW_AVAILABLE = True +except ImportError: + PYMONGOARROW_AVAILABLE = False + class CollectionLoader: def __init__( @@ -345,15 +352,21 @@ def collection_documents( Returns: Iterable[DltResource]: A list of DLT resources for each collection to be loaded. """ + if data_item_format == "arrow" and not PYMONGOARROW_AVAILABLE: + dlt.common.logger.warning( + "'pymongoarrow' is not installed; falling back to standard MongoDB CollectionLoader." + ) + data_item_format = "object" + if parallel: if data_item_format == "arrow": LoaderClass = CollectionArrowLoaderParallel - elif data_item_format == "object": + else: LoaderClass = CollectionLoaderParallel # type: ignore else: if data_item_format == "arrow": LoaderClass = CollectionArrowLoader # type: ignore - elif data_item_format == "object": + else: LoaderClass = CollectionLoader # type: ignore loader = LoaderClass( diff --git a/sources/mongodb/requirements.txt b/sources/mongodb/requirements.txt index 45ac0bc3d..5240a44e2 100644 --- a/sources/mongodb/requirements.txt +++ b/sources/mongodb/requirements.txt @@ -1,3 +1,2 @@ -pymongo>=4.3.3 -pymongoarrow>=1.3.0 +pymongo>=3 dlt>=0.5.1 diff --git a/sources/sql_database_pipeline.py b/sources/sql_database_pipeline.py index 86e21845d..605895a48 100644 --- a/sources/sql_database_pipeline.py +++ b/sources/sql_database_pipeline.py @@ -338,9 +338,9 @@ def specify_columns_to_load() -> None: ) # Columns can be specified per table in env var (json array) or in `.dlt/config.toml` - os.environ["SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"] = ( - '["rfam_acc", "description"]' - ) + os.environ[ + "SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS" + ] = '["rfam_acc", "description"]' sql_alchemy_source = sql_database( 
"mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam?&binary_prefix=true", diff --git a/tests/mongodb/test_mongodb_source.py b/tests/mongodb/test_mongodb_source.py index b385a9a6e..00993ba59 100644 --- a/tests/mongodb/test_mongodb_source.py +++ b/tests/mongodb/test_mongodb_source.py @@ -1,5 +1,8 @@ -import bson import json +from unittest import mock + +import bson +import dlt import pyarrow import pytest from pendulum import DateTime, timezone @@ -404,3 +407,26 @@ def test_filter_intersect(destination_name): with pytest.raises(PipelineStepFailed): pipeline.run(movies) + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +@pytest.mark.parametrize("data_item_format", ["object", "arrow"]) +def test_mongodb_without_pymongoarrow( + destination_name: str, data_item_format: str +) -> None: + with mock.patch.dict("sys.modules", {"pymongoarrow": None}): + pipeline = dlt.pipeline( + pipeline_name="test_mongodb_without_pymongoarrow", + destination=destination_name, + dataset_name="test_mongodb_without_pymongoarrow_data", + full_refresh=True, + ) + + comments = mongodb_collection( + collection="comments", limit=10, data_item_format=data_item_format + ) + load_info = pipeline.run(comments) + + assert load_info.loads_ids != [] + table_counts = load_table_counts(pipeline, "comments") + assert table_counts["comments"] == 10 diff --git a/tests/rest_api/test_rest_api_source_processed.py b/tests/rest_api/test_rest_api_source_processed.py index cc04c27a6..85e07f9e6 100644 --- a/tests/rest_api/test_rest_api_source_processed.py +++ b/tests/rest_api/test_rest_api_source_processed.py @@ -40,7 +40,6 @@ def test_rest_api_source_filtered(mock_api_server) -> None: def test_rest_api_source_exclude_columns(mock_api_server) -> None: - def exclude_columns(columns: List[str]) -> Callable: def pop_columns(resource: DltResource) -> DltResource: for col in columns: @@ -73,7 +72,6 @@ def pop_columns(resource: DltResource) -> DltResource: def 
test_rest_api_source_anonymize_columns(mock_api_server) -> None: - def anonymize_columns(columns: List[str]) -> Callable: def empty_columns(resource: DltResource) -> DltResource: for col in columns: @@ -106,7 +104,6 @@ def empty_columns(resource: DltResource) -> DltResource: def test_rest_api_source_map(mock_api_server) -> None: - def lower_title(row): row["title"] = row["title"].lower() return row @@ -133,7 +130,6 @@ def lower_title(row): def test_rest_api_source_filter_and_map(mock_api_server) -> None: - def id_by_10(row): row["id"] = row["id"] * 10 return row @@ -211,7 +207,6 @@ def test_rest_api_source_filtered_child(mock_api_server) -> None: def test_rest_api_source_filtered_and_map_child(mock_api_server) -> None: - def extend_body(row): row["body"] = f"{row['_posts_title']} - {row['body']}" return row diff --git a/tests/sql_database/test_sql_database_source.py b/tests/sql_database/test_sql_database_source.py index abecde9a8..9d578ff7e 100644 --- a/tests/sql_database/test_sql_database_source.py +++ b/tests/sql_database/test_sql_database_source.py @@ -97,9 +97,9 @@ def test_pass_engine_credentials(sql_source_db: SQLAlchemySourceDB) -> None: def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None: # set the credentials per table name - os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"] = ( - sql_source_db.engine.url.render_as_string(False) - ) + os.environ[ + "SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS" + ] = sql_source_db.engine.url.render_as_string(False) table = sql_table(table="chat_message", schema=sql_source_db.schema) assert table.name == "chat_message" assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] @@ -119,9 +119,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None: assert len(list(table)) == 10 # make it fail on cursor - os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = ( - "updated_at_x" - ) + os.environ[ + 
"SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH" + ] = "updated_at_x" table = sql_table(table="chat_message", schema=sql_source_db.schema) with pytest.raises(ResourceExtractionError) as ext_ex: len(list(table)) @@ -130,9 +130,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None: def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None: # set the credentials per table name - os.environ["SOURCES__SQL_DATABASE__CREDENTIALS"] = ( - sql_source_db.engine.url.render_as_string(False) - ) + os.environ[ + "SOURCES__SQL_DATABASE__CREDENTIALS" + ] = sql_source_db.engine.url.render_as_string(False) # applies to both sql table and sql database table = sql_table(table="chat_message", schema=sql_source_db.schema) assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] @@ -155,9 +155,9 @@ def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None: assert len(list(database)) == 10 # make it fail on cursor - os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = ( - "updated_at_x" - ) + os.environ[ + "SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH" + ] = "updated_at_x" table = sql_table(table="chat_message", schema=sql_source_db.schema) with pytest.raises(ResourceExtractionError) as ext_ex: len(list(table)) @@ -275,9 +275,9 @@ def test_load_sql_table_incremental( """Run pipeline twice. Insert more rows after first run and ensure only those rows are stored after the second run. """ - os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = ( - "updated_at" - ) + os.environ[ + "SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH" + ] = "updated_at" pipeline = make_pipeline(destination_name) tables = ["chat_message"]