Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ v0.47.0 - Persistent listeners, schema builders, and performance polish
rules, so ``optimize_joins``, ``optimize_predicates``, and
``simplify_expressions`` now disable only their matching steps instead of
always running the full default pipeline.
* Passing a sqlglot ``Dialect`` class to EXPLAIN builders or
``StatementConfig.dialect`` now resolves to the correct dialect name.
* Avoided parser round-trips for simple builder identifiers and MERGE JSON
source construction while preserving rendered SQL.
* Deferred temporal version-generator registration until temporal builder APIs
are used. Code that hand-builds ``exp.Version`` nodes should call
``sqlspec.builder.register_version_generators()`` before rendering them.
* Routed async pool teardown through the base config lifecycle path so
``on_pool_destroy`` and ``on_pool_destroying`` fire consistently across async
adapters.
Expand Down
3 changes: 0 additions & 3 deletions sqlspec/builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,3 @@
"sql",
"to_expression",
)

# Register temporal query SQL generators on module import
register_version_generators()
5 changes: 4 additions & 1 deletion sqlspec/builder/_explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any

from mypy_extensions import trait
from sqlglot import Dialect
from typing_extensions import Self

from sqlspec.core import SQL, StatementConfig
Expand Down Expand Up @@ -58,7 +59,9 @@ def normalize_dialect_name(dialect: "DialectType | None") -> str | None:
return None
if isinstance(dialect, str):
return dialect.lower()
return dialect.__class__.__name__.lower()
if isinstance(dialect, type) and issubclass(dialect, Dialect):
return dialect.__name__.lower()
return type(dialect).__name__.lower()


def build_postgres_explain(statement_sql: str, options: "ExplainOptions") -> str:
Expand Down
27 changes: 27 additions & 0 deletions sqlspec/builder/_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""SQLGlot generator registration helpers."""

from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
from sqlglot.generator import Generator

__all__ = ("invalidate_generator_dispatch",)


def invalidate_generator_dispatch(*generator_classes: "type[object]") -> None:
"""Clear SQLGlot generator dispatch entries after transform registration.

SQLGlot caches generator dispatch lookups. SQLSpec mutates ``TRANSFORMS``
on existing generator classes for compiled sqlglot compatibility, so those
cache entries need to be invalidated after registration.

Args:
*generator_classes: SQLGlot generator classes whose dispatch caches should be cleared.
"""
try:
from sqlglot.generator import _DISPATCH_CACHE # pyright: ignore[reportPrivateUsage,reportPrivateImportUsage]
except ImportError:
return

for generator_class in generator_classes:
_DISPATCH_CACHE.pop(cast("type[Generator]", generator_class), None)
3 changes: 3 additions & 0 deletions sqlspec/builder/_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from sqlspec.builder._base import BuiltQuery, QueryBuilder
from sqlspec.builder._parsing_utils import parse_table_expression
from sqlspec.builder._temporal import register_version_generators
from sqlspec.exceptions import SQLBuilderError
from sqlspec.utils.type_guards import has_expression_and_parameters, has_expression_and_sql, has_parameter_builder

Expand Down Expand Up @@ -114,6 +115,8 @@ def _apply_lateral_modifier(join_expr: exp.Join) -> None:
def _attach_as_of_version(
table_expr: exp.Expr, alias: str | None, as_of: Any, as_of_type: str | None = None
) -> exp.Expr:
register_version_generators()

inner_table = table_expr.copy()
target_alias = alias

Expand Down
60 changes: 43 additions & 17 deletions sqlspec/builder/_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from itertools import starmap
from typing import TYPE_CHECKING, Any, cast

import sqlglot as sg
from mypy_extensions import trait
from sqlglot import exp
from sqlglot.errors import ParseError
Expand Down Expand Up @@ -237,14 +236,32 @@ def _create_postgres_json_source(
alias_name = alias or "src"
recordset_alias = f"{alias_name}_data"

column_type_spec = ", ".join([f"{col} {self._infer_postgres_type(sample_values.get(col))}" for col in columns])
column_selects = ", ".join(columns)
from_sql = (
f"SELECT {column_selects} FROM jsonb_to_recordset(:{json_param_name}::jsonb) AS "
f"{recordset_alias}({column_type_spec})"
recordset_table = exp.Table(
this=exp.Anonymous(
this="jsonb_to_recordset",
expressions=[
exp.Cast(
this=exp.Placeholder(this=json_param_name), to=exp.DataType.build("JSONB", dialect="postgres")
)
],
),
alias=exp.TableAlias(
this=exp.to_identifier(recordset_alias),
columns=[
exp.ColumnDef(
this=exp.to_identifier(column),
kind=exp.DataType.build(
self._infer_postgres_type(sample_values.get(column)), dialect="postgres"
),
)
for column in columns
],
),
)

parsed = sg.parse_one(from_sql, dialect="postgres")
parsed = exp.Select(
expressions=[exp.column(column) for column in columns], from_=exp.From(this=recordset_table)
)
return exp.Subquery(
this=parsed,
alias=exp.TableAlias(
Expand All @@ -265,17 +282,26 @@ def _create_oracle_json_source(
if value is not None and column not in sample_values:
sample_values[column] = value

json_columns = [
f"{column} {self._infer_oracle_type(sample_values.get(column))} PATH '$.{column}'" for column in columns
]

alias_name = alias or "src"
column_selects = ", ".join(columns)
columns_clause = ", ".join(json_columns)

from_sql = f"SELECT {column_selects} FROM JSON_TABLE(:{json_param_name}, '$[*]' COLUMNS ({columns_clause}))"

parsed = sg.parse_one(from_sql, dialect="oracle")
json_table = exp.Table(
this=exp.JSONTable(
this=exp.Placeholder(this=json_param_name),
path=exp.Literal.string("$[*]"),
schema=exp.JSONSchema(
expressions=[
exp.JSONColumnDef(
this=exp.to_identifier(column),
kind=exp.DataType.build(
self._infer_oracle_type(sample_values.get(column)), dialect="oracle"
),
path=exp.Literal.string(f"$.{column}"),
)
for column in columns
]
),
)
)
parsed = exp.Select(expressions=[exp.column(column) for column in columns], from_=exp.From(this=json_table))
return exp.Subquery(this=parsed, alias=exp.TableAlias(this=exp.to_identifier(alias_name)))

def _infer_postgres_type(self, value: "Any") -> str:
Expand Down
76 changes: 75 additions & 1 deletion sqlspec/builder/_parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import contextlib
import re
from typing import TYPE_CHECKING, Any, Final

from sqlglot import exp, maybe_parse
Expand Down Expand Up @@ -33,6 +34,47 @@
)

ALIAS_PARTS_EXPECTED_COUNT = 2
QUALIFIED_IDENTIFIER_PARTS = 2
_SIMPLE_IDENTIFIER_RE: Final["re.Pattern[str]"] = re.compile(
r"^[A-Za-z_][A-Za-z0-9_$]*(?:\.[A-Za-z_][A-Za-z0-9_$]*){0,2}$"
)
_BARE_KEYWORDS: Final[frozenset[str]] = frozenset({
"all",
"and",
"any",
"asc",
"between",
"case",
"current_date",
"current_time",
"current_timestamp",
"current_user",
"default",
"delete",
"desc",
"distinct",
"end",
"exists",
"false",
"from",
"in",
"insert",
"interval",
"is",
"like",
"localtime",
"localtimestamp",
"not",
"null",
"or",
"select",
"session_user",
"some",
"true",
"update",
"user",
"where",
})
_PARAMETER_VALIDATOR = ParameterValidator()


Expand Down Expand Up @@ -69,6 +111,23 @@ def _merge_sql_parameters(sql_obj: Any, builder: Any) -> None:
builder.add_parameter(param_value, name=param_name)


def _is_simple_identifier(value: str) -> bool:
stripped = value.strip()
if not _SIMPLE_IDENTIFIER_RE.fullmatch(stripped):
return False
return "." in stripped or stripped.lower() not in _BARE_KEYWORDS


def _simple_column_expression(value: str) -> exp.Column:
parts = value.strip().split(".")
identifiers = [exp.Identifier(this=part, quoted=False) for part in parts]
if len(parts) == 1:
return exp.Column(this=identifiers[0])
if len(parts) == QUALIFIED_IDENTIFIER_PARTS:
return exp.Column(this=identifiers[1], table=identifiers[0])
return exp.Column(this=identifiers[2], table=identifiers[1], db=identifiers[0])


def parse_column_expression(column_input: str | exp.Expr | Any, builder: Any | None = None) -> exp.Expr:
"""Parse a column input that might be a complex expression.

Expand All @@ -92,6 +151,8 @@ def parse_column_expression(column_input: str | exp.Expr | Any, builder: Any | N
return column_input

if isinstance(column_input, str):
if _is_simple_identifier(column_input):
return _simple_column_expression(column_input)
return exp.maybe_parse(column_input) or exp.column(column_input)

if has_expression_and_sql(column_input):
Expand Down Expand Up @@ -126,6 +187,9 @@ def parse_table_expression(
base_table, alias = parts
return exp.to_table(base_table, alias=alias, dialect=dialect)

if _is_simple_identifier(table_input):
return exp.to_table(table_input, alias=explicit_alias, dialect=dialect)

with contextlib.suppress(Exception):
parsed: exp.Expr | None = exp.maybe_parse(f"SELECT * FROM {table_input}", dialect=dialect)
if isinstance(parsed, exp.Select):
Expand Down Expand Up @@ -157,7 +221,17 @@ def parse_order_expression(order_input: str | exp.Expr) -> exp.Expr:
if isinstance(order_input, exp.Expr):
return order_input

parsed = maybe_parse(str(order_input), into=exp.Ordered)
order_value = str(order_input)
parts = order_value.rsplit(None, 1)
if len(parts) == ALIAS_PARTS_EXPECTED_COUNT and parts[1].lower() in {"asc", "desc"}:
base, direction = parts
if _is_simple_identifier(base):
column_expr = _simple_column_expression(base)
if direction.lower() == "desc":
return exp.Ordered(this=column_expr, desc=True, nulls_first=False)
return exp.Ordered(this=column_expr, desc=False, nulls_first=True)

parsed = maybe_parse(order_value, into=exp.Ordered)
if parsed:
return parsed

Expand Down
Loading
Loading