Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support trim with numeric and variant types #69

Merged
merged 3 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: cast trim input to varchar implicitly
  • Loading branch information
seruman committed Apr 16, 2024
commit bda6fc3b298247f8eb21cf79ec34f931df795c8f
2 changes: 2 additions & 0 deletions fakesnow/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def _execute(
.transform(transforms.tag)
.transform(transforms.semi_structured_types)
.transform(transforms.try_parse_json)
# NOTE: trim_cast_varchar must be before json_extract_cast_as_varchar
.transform(transforms.trim_cast_varchar)
# indices_to_json_extract must be before regex_substr
.transform(transforms.indices_to_json_extract)
.transform(transforms.json_extract_cast_as_varchar)
Expand Down
15 changes: 15 additions & 0 deletions fakesnow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,21 @@ def timestamp_ntz_ns(expression: exp.Expression) -> exp.Expression:
return expression


def trim_cast_varchar(expression: exp.Expression) -> exp.Expression:
    """Cast the operand of TRIM to VARCHAR, as Snowflake does implicitly.

    Snowflake's TRIM accepts non-string input (e.g. numbers or variant fields)
    and casts it to VARCHAR before trimming, so the cast is made explicit here.

    Args:
        expression: Any node of the parsed statement; nodes other than
            ``exp.Trim`` are returned untouched.

    Returns:
        The original node when it is not a TRIM or when its operand is already
        cast to VARCHAR/TEXT; otherwise a new ``exp.Trim`` whose operand is
        wrapped in ``CAST(... AS VARCHAR)`` and whose remaining arguments are
        carried over from the original node.
    """
    if not isinstance(expression, exp.Trim):
        return expression

    operand = expression.this
    # An explicit VARCHAR/TEXT cast is already present: nothing to add.
    if isinstance(operand, exp.Cast) and operand.to.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.TEXT):
        return expression

    varchar_operand = exp.Cast(
        this=operand,
        to=exp.DataType(this=exp.DataType.Type.VARCHAR, nested=False, prefix=False),
    )

    # Replace only the operand. Rebuilding the node from ``this`` alone would
    # drop the other TRIM arguments (the characters to trim, position,
    # collation) and turn e.g. TRIM(col, 'x') into a plain whitespace trim.
    return exp.Trim(**{**expression.args, "this": varchar_operand})


def try_parse_json(expression: exp.Expression) -> exp.Expression:
"""Convert TRY_PARSE_JSON() to TRY_CAST(... as JSON).

Expand Down
18 changes: 18 additions & 0 deletions tests/test_fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,6 +1393,24 @@ def test_transactions(conn: snowflake.connector.SnowflakeConnection):
assert cur.fetchall() == [("Statement executed successfully.",)]


def test_trim_cast_varchar(conn: snowflake.connector.SnowflakeConnection):
with conn.cursor() as cur:
cur.execute("create or replace table trim_cast_varchar(id number, name varchar);")
cur.execute("insert into trim_cast_varchar(id, name) values (1, ' name 1 '), (2, 'name2 ');")
cur.execute("select trim(id), trim(name) from trim_cast_varchar;")

assert cur.fetchall() == [("1", "name 1"), ("2", "name2")]

with conn.cursor() as cur:
cur.execute("create or replace table trim_cast_varchar_variant_field(data variant);")
cur.execute(
"""insert into trim_cast_varchar_variant_field(data) values ('{"k1": " v11 "}'),('{"k1": 21}');"""
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails for me on a real Snowflake instance with:

ProgrammingError: 002023 (22000): SQL compilation error:
Expression type does not match column data type, expecting VARIANT but got VARCHAR(18) for column DATA

I'm guessing we need to convert it to VARIANT first?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad again, updated.

As in #78, the tests seem to be failing, I'm guessing because of the sqlalchemy version:
#78 (comment)

cur.execute("select trim(data:k1) from trim_cast_varchar_variant_field;")

assert cur.fetchall() == [("v11",), ("21",)]


def test_unquoted_identifiers_are_upper_cased(dcur: snowflake.connector.cursor.SnowflakeCursor):
dcur.execute("create table customers (id int, first_name varchar, last_name varchar)")
dcur.execute("insert into customers values (1, 'Jenny', 'P')")
Expand Down
18 changes: 18 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
to_decimal,
to_timestamp,
to_timestamp_ntz,
trim_cast_varchar,
try_parse_json,
try_to_decimal,
upper_case_unquoted_identifiers,
Expand Down Expand Up @@ -770,6 +771,23 @@ def test_try_parse_json() -> None:
)


def test_trim_cast_varchar() -> None:
    """TRIM operands end up cast to TEXT exactly once in the DuckDB output.

    Covers an uncast column as well as operands that already carry an explicit
    VARCHAR or TEXT cast; all of them must render to the same SQL.
    """
    expected_sql = "SELECT TRIM(CAST(col AS TEXT)) FROM table1"
    statements = (
        "SELECT TRIM(col) FROM table1",
        "SELECT TRIM(col::varchar) FROM table1",
        "SELECT TRIM(col::TEXT) FROM table1",
    )

    for statement in statements:
        transformed = sqlglot.parse_one(statement).transform(trim_cast_varchar)
        assert transformed.sql(dialect="duckdb") == expected_sql


def test_upper_case_unquoted_identifiers() -> None:
assert (
sqlglot.parse_one("select name, name as fname from table1").transform(upper_case_unquoted_identifiers).sql()
Expand Down