Skip to content

Commit

Permalink
chore: bump sqlglot 19.5.1
Browse files Browse the repository at this point in the history
  • Loading branch information
tekumara committed Dec 26, 2023
1 parent e027af4 commit 9177a7f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 42 deletions.
63 changes: 24 additions & 39 deletions fakesnow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import cast

import snowflake.connector
import sqlglot
from sqlglot import exp

Expand Down Expand Up @@ -140,7 +139,7 @@ def extract_text_length(expression: exp.Expression) -> exp.Expression:
for dt in expression.find_all(exp.DataType):
if dt.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.TEXT):
col_name = dt.parent and dt.parent.this and dt.parent.this.this
if dt_size := dt.find(exp.DataTypeSize):
if dt_size := dt.find(exp.DataTypeParam):
size = (
isinstance(dt_size.this, exp.Literal)
and isinstance(dt_size.this.this, str)
Expand Down Expand Up @@ -311,79 +310,65 @@ def parse_json(expression: exp.Expression) -> exp.Expression:
def regex_replace(expression: exp.Expression) -> exp.Expression:
"""Transform regex_replace expressions from snowflake to duckdb."""

if (
isinstance(expression, exp.Anonymous)
and isinstance(expression.this, str)
and expression.this.upper() == "REGEXP_REPLACE"
):
expressions = expression.expressions

if len(expressions) > 3:
if isinstance(expression, exp.RegexpReplace) and isinstance(expression.expression, exp.Literal):
if len(expression.args) > 3:
# see https://docs.snowflake.com/en/sql-reference/functions/regexp_replace
raise NotImplementedError(
"REGEXP_REPLACE with additional parameters (eg: <position>, <occurrence>, <parameters>) not supported"
)

# pattern: snowflake requires escaping backslashes in single-quoted string constants, but duckdb doesn't
# see https://docs.snowflake.com/en/sql-reference/functions-regexp#label-regexp-escape-character-caveats
expressions[1].args["this"] = expressions[1].this.replace("\\\\", "\\")
expression.args["expression"] = exp.Literal(
this=expression.expression.this.replace("\\\\", "\\"), is_string=True
)

if len(expressions) == 2:
if not expression.args.get("replacement"):
# if no replacement string, the snowflake default is ''
expressions.append(exp.Literal(this="", is_string=True))
expression.args["replacement"] = exp.Literal(this="", is_string=True)

# snowflake regex replacements are global
expressions.append(exp.Literal(this="g", is_string=True))
expression.args["modifiers"] = exp.Literal(this="g", is_string=True)

return expression


def regex_substr(expression: exp.Expression) -> exp.Expression:
"""Transform regex_substr expressions from snowflake to duckdb."""
"""Transform regex_substr expressions from snowflake to duckdb.
if (
isinstance(expression, exp.Anonymous)
and isinstance(expression.this, str)
and expression.this.upper() == "REGEXP_SUBSTR"
):
expressions = expression.expressions

if len(expressions) < 2:
raise snowflake.connector.errors.ProgrammingError(
msg=f"SQL compilation error:\nnot enough arguments for function [{expression.sql()}], expected 2, got {len(expressions)}", # noqa: E501
errno=938,
sqlstate="22023",
)
See https://docs.snowflake.com/en/sql-reference/functions/regexp_substr
"""

subject = expressions[0]
if isinstance(expression, exp.RegexpExtract):
subject = expression.this

# pattern: snowflake requires escaping backslashes in single-quoted string constants, but duckdb doesn't
# see https://docs.snowflake.com/en/sql-reference/functions-regexp#label-regexp-escape-character-caveats
pattern = expressions[1]
pattern = expression.expression
pattern.args["this"] = pattern.this.replace("\\\\", "\\")

# number of characters from the beginning of the string where the function starts searching for matches
try:
position = expressions[2]
except IndexError:
position = expression.args["position"]
except KeyError:
position = exp.Literal(this="1", is_string=False)

# which occurrence of the pattern to match
try:
occurrence = expressions[3]
except IndexError:
occurrence = expression.args["occurrence"]
except KeyError:
occurrence = exp.Literal(this="1", is_string=False)

try:
regex_parameters_value = str(expressions[4].this)
regex_parameters_value = str(expression.args["parameters"].this)
# 'e' parameter doesn't make sense for duckdb
regex_parameters = exp.Literal(this=regex_parameters_value.replace("e", ""), is_string=True)
except IndexError:
except KeyError:
regex_parameters = exp.Literal(is_string=True)

try:
group_num = expressions[5]
except IndexError:
group_num = expression.args["group"]
except KeyError:
if isinstance(regex_parameters.this, str) and "e" in regex_parameters.this:
group_num = exp.Literal(this="1", is_string=False)
else:
Expand Down Expand Up @@ -555,7 +540,7 @@ def timestamp_ntz_ns(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, exp.DataType)
and expression.this == exp.DataType.Type.TIMESTAMP
and exp.DataTypeSize(this=exp.Literal(this="9", is_string=False)) in expression.expressions
and exp.DataTypeParam(this=exp.Literal(this="9", is_string=False)) in expression.expressions
):
new = expression.copy()
del new.args["expressions"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [
"duckdb~=0.8.0",
"pyarrow",
"snowflake-connector-python",
"sqlglot~=16.8.1",
"sqlglot~=19.5.1",
]

[project.urls]
Expand Down
1 change: 1 addition & 0 deletions tests/test_fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@ def test_tags_noop(cur: snowflake.connector.cursor.SnowflakeCursor):

def test_timestamp(cur: snowflake.connector.cursor.SnowflakeCursor):
cur.execute("SELECT to_timestamp(0)")
# NB: duckdb~=0.9 returns a datetime with utc timezone
assert cur.fetchall() == [(datetime.datetime(1970, 1, 1, 0, 0),)]


Expand Down
6 changes: 4 additions & 2 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_parse_json() -> None:
assert (
sqlglot.parse_one("""insert into table1 (name) select parse_json('{"first":"foo", "last":"bar"}')""")
.transform(parse_json)
.sql()
.sql(dialect="duckdb")
== """INSERT INTO table1 (name) SELECT JSON('{"first":"foo", "last":"bar"}')"""
)

Expand All @@ -147,7 +147,9 @@ def test_regex_replace() -> None:

def test_regex_substr() -> None:
assert (
sqlglot.parse_one("SELECT regexp_substr(string1, 'the\\\\W+\\\\w+')").transform(regex_substr).sql()
sqlglot.parse_one("SELECT regexp_substr(string1, 'the\\\\W+\\\\w+')", read="snowflake")
.transform(regex_substr)
.sql()
== "SELECT REGEXP_EXTRACT_ALL(string1[1 : ], 'the\\W+\\w+', 0, '')[1]"
)

Expand Down

0 comments on commit 9177a7f

Please sign in to comment.