Skip to content

Commit

Permalink
feat(python): support sqlalchemy/pandas backed write_database (#7322)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Mar 3, 2023
1 parent 5d0258f commit 5989465
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 14 deletions.
47 changes: 38 additions & 9 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2791,14 +2791,11 @@ def write_database(
table_name: str,
connection_uri: str,
*,
mode: DbWriteMode = "create",
engine: DbWriteEngine = "adbc",
if_exists: DbWriteMode = "fail",
engine: DbWriteEngine = "sqlalchemy",
) -> None:
"""
Write a polars frame to an SQL database.
ADBC
Currently this can only connect to sqlite and postgres.
Write a polars frame to a database.
Parameters
----------
Expand All @@ -2808,21 +2805,53 @@ def write_database(
Connection uri, for example
* "postgresql://username:password@server:port/database"
mode : {'append', 'create'}
if_exists : {'append', 'replace', 'fail'}
The insert mode.
'create' will create a new database table.
'replace' will create a new database table, overwriting an existing one.
'append' will append to an existing table.
engine : {'adbc'}
'fail' will fail if table already exists.
engine : {'sqlalchemy', 'adbc'}
Select the engine used for writing the data.
"""
from polars.io.database import _open_adbc_connection

if engine == "adbc":
if if_exists == "fail":
raise ValueError("'if_exists' not yet supported with engine ADBC")
elif if_exists == "replace":
mode = "create"
elif if_exists == "append":
mode = "append"
else:
raise ValueError(
f"Value for 'if_exists'={if_exists} was unexpected. "
f"Choose one of: {'fail', 'replace', 'append'}."
)
with _open_adbc_connection(connection_uri) as conn:
cursor = conn.cursor()
cursor.adbc_ingest(table_name, self.to_arrow(), mode)
cursor.close()
conn.commit()
elif engine == "sqlalchemy":
if parse_version(pd.__version__) < parse_version("1.5"):
raise ModuleNotFoundError(
f"Writing with engine 'sqlalchemy' requires Pandas 1.5.x or higher, found Pandas {pd.__version__}."
)
try:
from sqlalchemy import create_engine
except ImportError as err:
raise ImportError(
"'sqlalchemy' not found. Install polars with 'pip install polars[sqlalchemy]'."
) from err

engine = create_engine(connection_uri)

# this conversion to pandas is zero-copy
# so we can utilize their sql utils for free
self.to_pandas(use_pyarrow_extension_array=True).to_sql(
name=table_name, con=engine, if_exists=if_exists
)

else:
raise ValueError(f"'engine' {engine} is not supported.")

Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/internals/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@
TransferEncoding: TypeAlias = Literal["hex", "base64"]
CorrelationMethod: TypeAlias = Literal["pearson", "spearman"]
DbReadEngine: TypeAlias = Literal["adbc", "connectorx"]
DbWriteEngine: TypeAlias = Literal["adbc"]
DbWriteMode: TypeAlias = Literal["create", "append"]
DbWriteEngine: TypeAlias = Literal["sqlalchemy", "adbc"]
DbWriteMode: TypeAlias = Literal["replace", "append", "fail"]

# type signature for allowed frame init
FrameInitTypes: TypeAlias = "Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | pli.Series] | Sequence[Any] | np.ndarray[Any, Any] | pa.Table | pd.DataFrame | pli.Series"
4 changes: 3 additions & 1 deletion py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ xlsx2csv = ["xlsx2csv >= 0.8.0"]
deltalake = ["deltalake >= 0.6.2"]
timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_system == 'Windows'"]
matplotlib = ["matplotlib"]
sqlalchemy = ["sqlalchemy", "pandas"]
all = [
"polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,deltalake,timezone,matplotlib]",
"polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,deltalake,timezone,matplotlib,sqlalchemy]",
]

[tool.mypy]
Expand Down Expand Up @@ -83,6 +84,7 @@ module = [
"xlsxwriter.utility",
"xlsxwriter.worksheet",
"zoneinfo",
"sqlalchemy",
]
ignore_missing_imports = true

Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/io/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,15 @@ def test_write_database(
sample_df.write_database(
table_name="test_data",
connection_uri=f"sqlite:///{test_db}",
mode="create",
if_exists="replace",
engine=engine,
)

if mode == "append":
sample_df.write_database(
table_name="test_data",
connection_uri=f"sqlite:///{test_db}",
mode="append",
if_exists="append",
engine=engine,
)
sample_df = pl.concat([sample_df, sample_df])
Expand Down

0 comments on commit 5989465

Please sign in to comment.