diff --git a/docs/changelog.rst b/docs/changelog.rst index 61dbe12ae..58aaa5b1f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -114,6 +114,13 @@ v0.47.0 - Persistent listeners, schema builders, and performance polish rules, so ``optimize_joins``, ``optimize_predicates``, and ``simplify_expressions`` now disable only their matching steps instead of always running the full default pipeline. +* Passing a sqlglot ``Dialect`` class to EXPLAIN builders or + ``StatementConfig.dialect`` now resolves to the correct dialect name. +* Avoided parser round-trips for simple builder identifiers and MERGE JSON + source construction while preserving rendered SQL. +* Deferred temporal version-generator registration until temporal builder APIs + are used. Code that hand-builds ``exp.Version`` nodes should call + ``sqlspec.builder.register_version_generators()`` before rendering them. * Routed async pool teardown through the base config lifecycle path so ``on_pool_destroy`` and ``on_pool_destroying`` fire consistently across async adapters. diff --git a/sqlspec/builder/__init__.py b/sqlspec/builder/__init__.py index e0707988a..a919f42b9 100644 --- a/sqlspec/builder/__init__.py +++ b/sqlspec/builder/__init__.py @@ -176,6 +176,3 @@ "sql", "to_expression", ) - -# Register temporal query SQL generators on module import -register_version_generators() diff --git a/sqlspec/builder/_explain.py b/sqlspec/builder/_explain.py index 8a28bd4e8..2257aa7ec 100644 --- a/sqlspec/builder/_explain.py +++ b/sqlspec/builder/_explain.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any from mypy_extensions import trait +from sqlglot import Dialect from typing_extensions import Self from sqlspec.core import SQL, StatementConfig @@ -58,7 +59,9 @@ def normalize_dialect_name(dialect: "DialectType | None") -> str | None: return None if isinstance(dialect, str): return dialect.lower() - return dialect.__class__.__name__.lower() + if isinstance(dialect, type) and issubclass(dialect, Dialect): + return dialect.__name__.lower() + return type(dialect).__name__.lower() def build_postgres_explain(statement_sql: str, options: "ExplainOptions") -> str: diff --git a/sqlspec/builder/_generation.py b/sqlspec/builder/_generation.py new file mode 100644 index 000000000..89c1e6442 --- /dev/null +++ b/sqlspec/builder/_generation.py @@ -0,0 +1,27 @@ +"""SQLGlot generator registration helpers.""" + +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: + from sqlglot.generator import Generator + +__all__ = ("invalidate_generator_dispatch",) + + +def invalidate_generator_dispatch(*generator_classes: "type[object]") -> None: + """Clear SQLGlot generator dispatch entries after transform registration. + + SQLGlot caches generator dispatch lookups. SQLSpec mutates ``TRANSFORMS`` + on existing generator classes for compiled sqlglot compatibility, so those + cache entries need to be invalidated after registration. + + Args: + *generator_classes: SQLGlot generator classes whose dispatch caches should be cleared. + """ + try: + from sqlglot.generator import _DISPATCH_CACHE # pyright: ignore[reportPrivateUsage,reportPrivateImportUsage] + except ImportError: + return + + for generator_class in generator_classes: + _DISPATCH_CACHE.pop(cast("type[Generator]", generator_class), None) diff --git a/sqlspec/builder/_join.py b/sqlspec/builder/_join.py index 7570234a8..79f88e39e 100644 --- a/sqlspec/builder/_join.py +++ b/sqlspec/builder/_join.py @@ -12,6 +12,7 @@ from sqlspec.builder._base import BuiltQuery, QueryBuilder from sqlspec.builder._parsing_utils import parse_table_expression +from sqlspec.builder._temporal import register_version_generators from sqlspec.exceptions import SQLBuilderError from sqlspec.utils.type_guards import has_expression_and_parameters, has_expression_and_sql, has_parameter_builder @@ -114,6 +115,8 @@ def _apply_lateral_modifier(join_expr: exp.Join) -> None: def _attach_as_of_version( table_expr: exp.Expr, alias: str | None, as_of: Any, as_of_type: str | None = None ) -> exp.Expr: + register_version_generators() + inner_table = table_expr.copy() target_alias = alias diff --git a/sqlspec/builder/_merge.py b/sqlspec/builder/_merge.py index 0380d4a1f..aac0bdbc7 100644 --- a/sqlspec/builder/_merge.py +++ b/sqlspec/builder/_merge.py @@ -11,7 +11,6 @@ from itertools import starmap from typing import TYPE_CHECKING, Any, cast -import sqlglot as sg from mypy_extensions import trait from sqlglot import exp from sqlglot.errors import ParseError @@ -237,14 +236,32 @@ def _create_postgres_json_source( alias_name = alias or "src" recordset_alias = f"{alias_name}_data" - column_type_spec = ", ".join([f"{col} {self._infer_postgres_type(sample_values.get(col))}" for col in columns]) - column_selects = ", ".join(columns) - from_sql = ( - f"SELECT {column_selects} FROM jsonb_to_recordset(:{json_param_name}::jsonb) AS " - f"{recordset_alias}({column_type_spec})" + recordset_table = exp.Table( + this=exp.Anonymous( + this="jsonb_to_recordset", + expressions=[ + exp.Cast( + this=exp.Placeholder(this=json_param_name), to=exp.DataType.build("JSONB", dialect="postgres") + ) + ], + ), + alias=exp.TableAlias( + this=exp.to_identifier(recordset_alias), + columns=[ + exp.ColumnDef( + this=exp.to_identifier(column), + kind=exp.DataType.build( + self._infer_postgres_type(sample_values.get(column)), dialect="postgres" + ), + ) + for column in columns + ], + ), ) - parsed = sg.parse_one(from_sql, dialect="postgres") + parsed = exp.Select( + expressions=[exp.column(column) for column in columns], from_=exp.From(this=recordset_table) + ) return exp.Subquery( this=parsed, alias=exp.TableAlias( @@ -265,17 +282,26 @@ def _create_oracle_json_source( if value is not None and column not in sample_values: sample_values[column] = value - json_columns = [ - f"{column} {self._infer_oracle_type(sample_values.get(column))} PATH '$.{column}'" for column in columns - ] - alias_name = alias or "src" - column_selects = ", ".join(columns) - columns_clause = ", ".join(json_columns) - - from_sql = f"SELECT {column_selects} FROM JSON_TABLE(:{json_param_name}, '$[*]' COLUMNS ({columns_clause}))" - - parsed = sg.parse_one(from_sql, dialect="oracle") + json_table = exp.Table( + this=exp.JSONTable( + this=exp.Placeholder(this=json_param_name), + path=exp.Literal.string("$[*]"), + schema=exp.JSONSchema( + expressions=[ + exp.JSONColumnDef( + this=exp.to_identifier(column), + kind=exp.DataType.build( + self._infer_oracle_type(sample_values.get(column)), dialect="oracle" + ), + path=exp.Literal.string(f"$.{column}"), + ) + for column in columns + ] + ), + ) + ) + parsed = exp.Select(expressions=[exp.column(column) for column in columns], from_=exp.From(this=json_table)) return exp.Subquery(this=parsed, alias=exp.TableAlias(this=exp.to_identifier(alias_name))) def _infer_postgres_type(self, value: "Any") -> str: diff --git a/sqlspec/builder/_parsing_utils.py b/sqlspec/builder/_parsing_utils.py index 4ee882d5d..43a7326ca 100644 --- a/sqlspec/builder/_parsing_utils.py +++ b/sqlspec/builder/_parsing_utils.py @@ -5,6 +5,7 @@ """ import contextlib +import re from typing import TYPE_CHECKING, Any, Final from sqlglot import exp, maybe_parse @@ -33,6 +34,47 @@ ) ALIAS_PARTS_EXPECTED_COUNT = 2 +QUALIFIED_IDENTIFIER_PARTS = 2 +_SIMPLE_IDENTIFIER_RE: Final["re.Pattern[str]"] = re.compile( + r"^[A-Za-z_][A-Za-z0-9_$]*(?:\.[A-Za-z_][A-Za-z0-9_$]*){0,2}$" +) +_BARE_KEYWORDS: Final[frozenset[str]] = frozenset({ + "all", + "and", + "any", + "asc", + "between", + "case", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "default", + "delete", + "desc", + "distinct", + "end", + "exists", + "false", + "from", + "in", + "insert", + "interval", + "is", + "like", + "localtime", + "localtimestamp", + "not", + "null", + "or", + "select", + "session_user", + "some", + "true", + "update", + "user", + "where", +}) _PARAMETER_VALIDATOR = ParameterValidator() @@ -69,6 +111,23 @@ def _merge_sql_parameters(sql_obj: Any, builder: Any) -> None: builder.add_parameter(param_value, name=param_name) +def _is_simple_identifier(value: str) -> bool: + stripped = value.strip() + if not _SIMPLE_IDENTIFIER_RE.fullmatch(stripped): + return False + return "." in stripped or stripped.lower() not in _BARE_KEYWORDS + + +def _simple_column_expression(value: str) -> exp.Column: + parts = value.strip().split(".") + identifiers = [exp.Identifier(this=part, quoted=False) for part in parts] + if len(parts) == 1: + return exp.Column(this=identifiers[0]) + if len(parts) == QUALIFIED_IDENTIFIER_PARTS: + return exp.Column(this=identifiers[1], table=identifiers[0]) + return exp.Column(this=identifiers[2], table=identifiers[1], db=identifiers[0]) + + def parse_column_expression(column_input: str | exp.Expr | Any, builder: Any | None = None) -> exp.Expr: """Parse a column input that might be a complex expression. @@ -92,6 +151,8 @@ def parse_column_expression(column_input: str | exp.Expr | Any, builder: Any | N return column_input if isinstance(column_input, str): + if _is_simple_identifier(column_input): + return _simple_column_expression(column_input) return exp.maybe_parse(column_input) or exp.column(column_input) if has_expression_and_sql(column_input): @@ -126,6 +187,9 @@ def parse_table_expression( base_table, alias = parts return exp.to_table(base_table, alias=alias, dialect=dialect) + if _is_simple_identifier(table_input): + return exp.to_table(table_input, alias=explicit_alias, dialect=dialect) + with contextlib.suppress(Exception): parsed: exp.Expr | None = exp.maybe_parse(f"SELECT * FROM {table_input}", dialect=dialect) if isinstance(parsed, exp.Select): @@ -157,7 +221,17 @@ def parse_order_expression(order_input: str | exp.Expr) -> exp.Expr: if isinstance(order_input, exp.Expr): return order_input - parsed = maybe_parse(str(order_input), into=exp.Ordered) + order_value = str(order_input) + parts = order_value.rsplit(None, 1) + if len(parts) == ALIAS_PARTS_EXPECTED_COUNT and parts[1].lower() in {"asc", "desc"}: + base, direction = parts + if _is_simple_identifier(base): + column_expr = _simple_column_expression(base) + if direction.lower() == "desc": + return exp.Ordered(this=column_expr, desc=True, nulls_first=False) + return exp.Ordered(this=column_expr, desc=False, nulls_first=True) + + parsed = maybe_parse(order_value, into=exp.Ordered) if parsed: return parsed diff --git a/sqlspec/builder/_temporal.py b/sqlspec/builder/_temporal.py index 52c203f56..2a2735428 100644 --- a/sqlspec/builder/_temporal.py +++ b/sqlspec/builder/_temporal.py @@ -11,39 +11,37 @@ - CockroachDB (Postgres): table AS OF SYSTEM TIME timestamp """ +from typing import TYPE_CHECKING + from sqlglot import exp -from sqlglot.dialects.bigquery import BigQuery -from sqlglot.dialects.duckdb import DuckDB -from sqlglot.dialects.oracle import Oracle -from sqlglot.dialects.postgres import Postgres -from sqlglot.dialects.snowflake import Snowflake -from sqlglot.generator import ( - _DISPATCH_CACHE, # pyright: ignore[reportPrivateUsage] - Generator, -) -from sqlglot.generators.bigquery import BigQueryGenerator -from sqlglot.generators.duckdb import DuckDBGenerator -from sqlglot.generators.oracle import OracleGenerator -from sqlglot.generators.postgres import PostgresGenerator -from sqlglot.generators.snowflake import SnowflakeGenerator + +from sqlspec.builder._generation import invalidate_generator_dispatch + +if TYPE_CHECKING: + from sqlglot.generator import Generator + from sqlglot.generators.bigquery import BigQueryGenerator + from sqlglot.generators.duckdb import DuckDBGenerator + from sqlglot.generators.oracle import OracleGenerator + from sqlglot.generators.postgres import PostgresGenerator + from sqlglot.generators.snowflake import SnowflakeGenerator __all__ = ("create_temporal_table", "register_version_generators") -def _oracle_version_sql(self: OracleGenerator, expression: exp.Version) -> str: +def _oracle_version_sql(self: "OracleGenerator", expression: exp.Version) -> str: """Oracle: AS OF TIMESTAMP timestamp or AS OF SCN scn.""" expr = self.sql(expression, "expression") this = expression.name or "TIMESTAMP" return f"AS OF {this} {expr}" -def _bigquery_version_sql(self: BigQueryGenerator, expression: exp.Version) -> str: +def _bigquery_version_sql(self: "BigQueryGenerator", expression: exp.Version) -> str: """BigQuery: FOR SYSTEM_TIME AS OF timestamp.""" expr = self.sql(expression, "expression") return f"FOR SYSTEM_TIME AS OF {expr}" -def _snowflake_version_sql(self: SnowflakeGenerator, expression: exp.Version) -> str: +def _snowflake_version_sql(self: "SnowflakeGenerator", expression: exp.Version) -> str: """Snowflake: AT (TIMESTAMP => timestamp) or BEFORE (TIMESTAMP => ...). AS OF is mapped to AT, and BEFORE is supported for point-before queries. @@ -56,19 +54,19 @@ def _snowflake_version_sql(self: SnowflakeGenerator, expression: exp.Version) -> return f"AT ({this} => {expr})" -def _duckdb_version_sql(self: DuckDBGenerator, expression: exp.Version) -> str: +def _duckdb_version_sql(self: "DuckDBGenerator", expression: exp.Version) -> str: """DuckDB: AT (TIMESTAMP => timestamp).""" expr = self.sql(expression, "expression") return f"AT (TIMESTAMP => {expr})" -def _cockroachdb_version_sql(self: PostgresGenerator, expression: exp.Version) -> str: +def _cockroachdb_version_sql(self: "PostgresGenerator", expression: exp.Version) -> str: """CockroachDB (via Postgres dialect): AS OF SYSTEM TIME timestamp.""" expr = self.sql(expression, "expression") return f"AS OF SYSTEM TIME {expr}" -def _default_version_sql(self: Generator, expression: exp.Version) -> str: +def _default_version_sql(self: "Generator", expression: exp.Version) -> str: """Default: AS OF SYSTEM TIME timestamp (CockroachDB style). When no dialect is specified, we default to CockroachDB/Postgres style @@ -92,6 +90,8 @@ def create_temporal_table( Returns: Table expression with version clause that generates dialect-specific SQL. """ + register_version_generators() + if isinstance(table, str): table_expr = exp.to_table(table) elif isinstance(table, exp.Table): @@ -129,6 +129,13 @@ def register_version_generators() -> None: if _VERSION_GENERATORS_REGISTERED: return + from sqlglot.dialects.bigquery import BigQuery + from sqlglot.dialects.duckdb import DuckDB + from sqlglot.dialects.oracle import Oracle + from sqlglot.dialects.postgres import Postgres + from sqlglot.dialects.snowflake import Snowflake + from sqlglot.generator import Generator + Generator.TRANSFORMS[exp.Version] = _default_version_sql BigQuery.Generator.TRANSFORMS[exp.Version] = _bigquery_version_sql @@ -137,16 +144,8 @@ def register_version_generators() -> None: DuckDB.Generator.TRANSFORMS[exp.Version] = _duckdb_version_sql Postgres.Generator.TRANSFORMS[exp.Version] = _cockroachdb_version_sql - # Invalidate sqlglot's per-class dispatch cache so new TRANSFORMS entries - # are picked up by the next Generator instantiation. - for gen_cls in ( - Generator, - BigQuery.Generator, - Oracle.Generator, - Snowflake.Generator, - DuckDB.Generator, - Postgres.Generator, - ): - _DISPATCH_CACHE.pop(gen_cls, None) + invalidate_generator_dispatch( + Generator, BigQuery.Generator, Oracle.Generator, Snowflake.Generator, DuckDB.Generator, Postgres.Generator + ) _VERSION_GENERATORS_REGISTERED = True diff --git a/sqlspec/builder/_vector_distance.py b/sqlspec/builder/_vector_distance.py index 2b466ac63..f5f628a7c 100644 --- a/sqlspec/builder/_vector_distance.py +++ b/sqlspec/builder/_vector_distance.py @@ -7,6 +7,8 @@ from sqlglot import exp +from sqlspec.builder._generation import invalidate_generator_dispatch + if TYPE_CHECKING: from sqlglot.generator import Generator @@ -271,19 +273,8 @@ def _register_with_sqlglot() -> None: _register_operator_transform(BigQuery.Generator.TRANSFORMS, _operator_sql_bigquery) _register_operator_transform(DuckDB.Generator.TRANSFORMS, _operator_sql_duckdb) - # sqlglot caches the dispatch table (built from TRANSFORMS) per Generator class - # in _DISPATCH_CACHE. We must invalidate stale entries so the next instantiation - # picks up our new Operator transforms. - from sqlglot.generator import _DISPATCH_CACHE # pyright: ignore[reportPrivateUsage] - - for gen_cls in ( - Generator, - Postgres.Generator, - MySQL.Generator, - Oracle.Generator, - BigQuery.Generator, - DuckDB.Generator, - ): - _DISPATCH_CACHE.pop(gen_cls, None) + invalidate_generator_dispatch( + Generator, Postgres.Generator, MySQL.Generator, Oracle.Generator, BigQuery.Generator, DuckDB.Generator + ) _SQLGLOT_VECTOR_DISTANCE_REGISTERED = True diff --git a/sqlspec/core/statement.py b/sqlspec/core/statement.py index a181f8c94..356bdef9a 100644 --- a/sqlspec/core/statement.py +++ b/sqlspec/core/statement.py @@ -7,7 +7,7 @@ import sqlglot from mypy_extensions import mypyc_attr -from sqlglot import exp +from sqlglot import Dialect, exp from sqlglot.errors import ParseError import sqlspec.exceptions @@ -464,7 +464,9 @@ def _normalize_dialect(self, dialect: "DialectType") -> "str | None": return None if isinstance(dialect, str): return dialect - return dialect.__class__.__name__.lower() + if isinstance(dialect, type) and issubclass(dialect, Dialect): + return dialect.__name__.lower() + return type(dialect).__name__.lower() def _get_raw_sql(self) -> str: """Return raw SQL, materializing deferred expression SQL when needed.""" diff --git a/sqlspec/dialects/postgres/_generators.py b/sqlspec/dialects/postgres/_generators.py index 8b8beccef..73c286da1 100644 --- a/sqlspec/dialects/postgres/_generators.py +++ b/sqlspec/dialects/postgres/_generators.py @@ -9,9 +9,9 @@ from sqlglot import exp from sqlglot.dialects.postgres import Postgres -from sqlglot.generator import _DISPATCH_CACHE # pyright: ignore[reportPrivateUsage] from sqlglot.generators.postgres import PostgresGenerator +from sqlspec.builder._generation import invalidate_generator_dispatch from sqlspec.builder._vector_distance import ( is_vector_distance_expression, render_vector_distance_postgres, @@ -50,9 +50,7 @@ def _postgres_extension_operator_sql(generator: PostgresGenerator, expression: e # patch on the base class to support both environments consistently. PostgresGenerator.TRANSFORMS[exp.Operator] = _postgres_extension_operator_sql -# Invalidate sqlglot's per-class dispatch cache so the patched Operator -# transform is picked up by the next Generator instantiation. -_DISPATCH_CACHE.pop(PostgresGenerator, None) +invalidate_generator_dispatch(PostgresGenerator) PGVectorGenerator = PostgresGenerator # pyright: ignore[reportAssignmentType] ParadeDBGenerator = PostgresGenerator # pyright: ignore[reportAssignmentType] diff --git a/sqlspec/dialects/postgres/_operators.py b/sqlspec/dialects/postgres/_operators.py index a26b1852d..b8d97a8bd 100644 --- a/sqlspec/dialects/postgres/_operators.py +++ b/sqlspec/dialects/postgres/_operators.py @@ -10,7 +10,7 @@ from sqlglot import exp from sqlglot.parsers.postgres import PostgresParser -from sqlglot.tokens import TokenType # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage] +from sqlglot.tokenizer_core import TokenType __all__ = ( "PARADEDB_OPERATOR_TOKENS", diff --git a/sqlspec/dialects/spanner/_generators.py b/sqlspec/dialects/spanner/_generators.py index 47dd3dcbd..2f1e62ff8 100644 --- a/sqlspec/dialects/spanner/_generators.py +++ b/sqlspec/dialects/spanner/_generators.py @@ -19,10 +19,11 @@ from typing import Any, Final, cast from sqlglot import exp -from sqlglot.generator import _DISPATCH_CACHE # pyright: ignore[reportPrivateUsage] from sqlglot.generators.bigquery import BigQueryGenerator from sqlglot.generators.postgres import PostgresGenerator +from sqlspec.builder._generation import invalidate_generator_dispatch + __all__ = ("SpangresGenerator", "SpannerGenerator") _TTL_MIN_COMPONENTS = 2 @@ -284,7 +285,7 @@ def _bq_create_transform(self: Any, expression: exp.Create) -> str: BigQueryGenerator.TRANSFORMS[exp.Properties] = _bq_properties_transform BigQueryGenerator.TRANSFORMS[exp.Create] = _bq_create_transform -_DISPATCH_CACHE.pop(BigQueryGenerator, None) +invalidate_generator_dispatch(BigQueryGenerator) SpannerGenerator = BigQueryGenerator # pyright: ignore[reportAssignmentType] @@ -315,6 +316,6 @@ def _pg_properties_transform(self: Any, expression: exp.Properties) -> str: PostgresGenerator.TRANSFORMS[exp.Property] = _pg_property_transform PostgresGenerator.TRANSFORMS[exp.Properties] = _pg_properties_transform -_DISPATCH_CACHE.pop(PostgresGenerator, None) +invalidate_generator_dispatch(PostgresGenerator) SpangresGenerator = PostgresGenerator # pyright: ignore[reportAssignmentType] diff --git a/tests/unit/builder/test_explain.py b/tests/unit/builder/test_explain.py index ec13b4131..34760e4f5 100644 --- a/tests/unit/builder/test_explain.py +++ b/tests/unit/builder/test_explain.py @@ -20,6 +20,7 @@ import pytest from sqlglot import exp +from sqlglot.dialects.postgres import Postgres from sqlspec.builder import ( Delete, @@ -1146,6 +1147,28 @@ def testnormalize_dialect_name_mixed_case_string(): assert normalize_dialect_name("PostgreSQL") == "postgresql" +@pytest.mark.parametrize( + ("dialect", "expected"), + [ + ("postgres", "postgres"), + ("POSTGRES", "postgres"), + (Postgres, "postgres"), + (Postgres(), "postgres"), + (None, None), + ], +) +def test_normalize_dialect_name_handles_postgres_dialect_inputs(dialect, expected): + """Test normalize_dialect_name handles dialect strings, classes, instances, and None.""" + assert normalize_dialect_name(dialect) == expected + + +def test_build_explain_sql_uses_postgres_builder_for_dialect_class(): + """Dialect classes should dispatch to the concrete EXPLAIN builder.""" + result = build_explain_sql("SELECT 1", ExplainOptions(analyze=True), Postgres) + + assert result == "EXPLAIN (ANALYZE) SELECT 1" + + # ----------------------------------------------------------------------------- # QueryBuilder Integration Tests - ExplainMixin # ----------------------------------------------------------------------------- diff --git a/tests/unit/builder/test_generation.py b/tests/unit/builder/test_generation.py new file mode 100644 index 000000000..0e99bb255 --- /dev/null +++ b/tests/unit/builder/test_generation.py @@ -0,0 +1,56 @@ +"""Tests for SQLGlot generator registration helpers.""" + +import builtins +from typing import Any + +from sqlglot import exp +from sqlglot.generator import Generator + +from sqlspec.builder._generation import invalidate_generator_dispatch + + +def test_invalidate_generator_dispatch_refreshes_transform_cache() -> None: + """Invalidation should make a fresh generator see TRANSFORMS mutations.""" + sentinel = object() + original_transform = Generator.TRANSFORMS.get(exp.Null, sentinel) + + try: + assert Generator().sql(exp.Null()) == "NULL" + + Generator.TRANSFORMS[exp.Null] = lambda _generator, _expression: "SQLSPEC_NULL" + assert Generator().sql(exp.Null()) == "NULL" + + invalidate_generator_dispatch(Generator) + + assert Generator().sql(exp.Null()) == "SQLSPEC_NULL" + finally: + if original_transform is sentinel: + Generator.TRANSFORMS.pop(exp.Null, None) + else: + Generator.TRANSFORMS[exp.Null] = original_transform # type: ignore[assignment] + invalidate_generator_dispatch(Generator) + + +def test_invalidate_generator_dispatch_unknown_class_is_noop() -> None: + """Unknown classes should be accepted as no-op invalidation targets.""" + + class UnknownGenerator: + pass + + invalidate_generator_dispatch(UnknownGenerator) + + +def test_invalidate_generator_dispatch_missing_cache_is_noop(monkeypatch) -> None: + """Missing private sqlglot cache imports should degrade silently.""" + original_import = builtins.__import__ + + def raising_import( + name: str, globals_: Any = None, locals_: Any = None, fromlist: tuple[str, ...] = (), level: int = 0 + ): + if name == "sqlglot.generator" and "_DISPATCH_CACHE" in fromlist: + raise ImportError("cache moved") + return original_import(name, globals_, locals_, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", raising_import) + + invalidate_generator_dispatch(Generator) diff --git a/tests/unit/builder/test_merge.py b/tests/unit/builder/test_merge.py index 326fbe07b..959116796 100644 --- a/tests/unit/builder/test_merge.py +++ b/tests/unit/builder/test_merge.py @@ -2,6 +2,7 @@ import pytest +import sqlspec.builder._merge as merge_module from sqlspec import sql from sqlspec.builder import Merge from sqlspec.exceptions import DialectNotSupportedError, SQLBuilderError @@ -571,6 +572,134 @@ def test_merge_using_single_dict_postgres_dialect() -> None: assert "json_data" in stmt.parameters +def test_merge_postgres_json_source_golden_sql_and_parameters() -> None: + """PostgreSQL JSON source SQL and parameters should stay byte-identical.""" + data = [{"id": 1, "name": "Widget", "price": 12.5, "active": True, "meta": {"x": 1}, "missing": None}] + query = ( + sql + .merge(dialect="postgres") + .into("products", alias="t") + .using(data, alias="src") + .on("t.id = src.id") + .when_matched_then_update(name="src.name", price="src.price") + .when_not_matched_then_insert(columns=["id", "name", "price", "active", "meta", "missing"]) + ) + + stmt = query.build() + source_expr = query.get_expression().args["using"] # type: ignore[index,union-attr] + + assert stmt.parameters == {"json_data": data} + assert source_expr.sql(dialect="postgres") == ( + "(SELECT id, name, price, active, meta, missing FROM " + "JSONB_TO_RECORDSET(CAST(%(json_data)s AS JSONB)) AS src_data(id INT, name TEXT, " + "price DOUBLE PRECISION, active BOOLEAN, meta JSONB, missing DECIMAL)) AS " + "src(id, name, price, active, meta, missing)" + ) + assert ( + stmt.sql + == """MERGE INTO products AS "t" +USING ( + SELECT + "id", + "name", + "price", + "active", + "meta", + "missing" + FROM JSONB_TO_RECORDSET(CAST(%(json_data)s AS JSONB)) AS "src_data"("id" INT, "name" TEXT, "price" DOUBLE PRECISION, "active" BOOLEAN, "meta" JSONB, "missing" DECIMAL) +) AS "src"("id", "name", "price", "active", "meta", "missing") +ON ( + "t"."id" = "src"."id" +) +WHEN MATCHED THEN UPDATE SET + "name" = "src"."name", + "price" = "src"."price" +WHEN NOT MATCHED THEN INSERT ("id", "name", "price", "active", "meta", "missing") VALUES ( + "src"."id", + "src"."name", + "src"."price", + "src"."active", + "src"."meta", + "src"."missing" +)""" + ) + + +def test_merge_oracle_json_source_golden_sql_and_parameters() -> None: + """Oracle JSON source SQL and parameters should stay byte-identical.""" + data = [{"id": 1, "name": "Widget", "price": 12.5, "active": True, "meta": {"x": 1}, "missing": None}] + query = ( + sql + .merge(dialect="oracle") + .into("products", alias="t") + .using(data, alias="src") + .on("t.id = src.id") + .when_matched_then_update(name="src.name", price="src.price") + .when_not_matched_then_insert(columns=["id", "name", "price", "active", "meta", "missing"]) + ) + + stmt = query.build() + source_expr = query.get_expression().args["using"] # type: ignore[index,union-attr] + + assert stmt.parameters == { + "json_payload": '[{"id":1,"name":"Widget","price":12.5,"active":true,"meta":{"x":1},"missing":null}]' + } + assert source_expr.sql(dialect="oracle") == ( + "(SELECT id, name, price, active, meta, missing FROM JSON_TABLE(:json_payload, '$[*]' " + "COLUMNS(id NUMBER PATH '$.id', name VARCHAR2(4000) PATH '$.name', price NUMBER PATH '$.price', " + "active NUMBER(1) PATH '$.active', meta JSON PATH '$.meta', missing VARCHAR2(4000) PATH '$.missing'))) src" + ) + assert ( + stmt.sql + == """MERGE INTO products t +USING ( + SELECT + id, + name, + price, + active, + meta, + missing + FROM JSON_TABLE(:json_payload, '$[*]' COLUMNS( + id NUMBER PATH '$.id', + name VARCHAR2(4000) PATH '$.name', + price NUMBER PATH '$.price', + active NUMBER(1) PATH '$.active', + meta JSON PATH '$.meta', + missing VARCHAR2(4000) PATH '$.missing' + )) +) src +ON ( + t.id = src.id +) +WHEN MATCHED THEN UPDATE SET + name = src.name, + price = src.price +WHEN NOT MATCHED THEN INSERT (id, name, price, active, meta, missing) VALUES (src.id, src.name, src.price, src.active, src.meta, src.missing)""" + ) + + +@pytest.mark.parametrize("dialect", ["postgres", "oracle"]) +def test_merge_json_source_avoids_parse_one(monkeypatch: pytest.MonkeyPatch, dialect: str) -> None: + """JSON source construction should not round-trip through sqlglot.parse_one.""" + data = [{"id": 1, "name": "Widget"}] + builder = sql.merge(dialect=dialect) + + def fail_parse_one(*args, **kwargs): # type: ignore[no-untyped-def] + raise AssertionError("parse_one should not be called for JSON source construction") + + merge_sqlglot = getattr(merge_module, "sg", None) + if merge_sqlglot is not None: + monkeypatch.setattr(merge_sqlglot, "parse_one", fail_parse_one) + + if dialect == "postgres": + source_expr = builder._create_postgres_json_source(data, ["id", "name"], True, "src") # pyright: ignore[reportPrivateUsage] + else: + source_expr = builder._create_oracle_json_source(data, ["id", "name"], "src") # pyright: ignore[reportPrivateUsage] + + assert "SELECT" in source_expr.sql(dialect=dialect) + + def test_merge_using_empty_list_raises_error() -> None: """Test MERGE with empty list raises appropriate error.""" with pytest.raises(SQLBuilderError, match="Cannot create USING clause from empty list"): diff --git a/tests/unit/builder/test_parsing_utils.py b/tests/unit/builder/test_parsing_utils.py index 3e3307298..fa9ddb958 100644 --- a/tests/unit/builder/test_parsing_utils.py +++ b/tests/unit/builder/test_parsing_utils.py @@ -5,10 +5,20 @@ was added to fix QueryBuilder parameter handling issues. """ +import contextlib + +import pytest from sqlglot import exp +from sqlglot.errors import ParseError +import sqlspec.builder._parsing_utils as parsing_utils from sqlspec import sql -from sqlspec.builder import parse_column_expression, parse_condition_expression +from sqlspec.builder import ( + parse_column_expression, + parse_condition_expression, + parse_order_expression, + parse_table_expression, +) from sqlspec.core import get_cache @@ -113,6 +123,117 @@ def test_parse_column_expression_qualified() -> None: assert isinstance(expr, exp.Expr) +_PARSING_CORPUS = [ + "name", + "users.name", + "db.users.name", + "col$1", + "_x", + "true", + "null", + "count", + "user", + "MAX(price)", + "name AS n", + "price * 2", + '"Quoted".col', + "name DESC", + "users.name asc", + "COUNT(*) DESC", + "name nulls first", +] + + +def _parse_column_oracle(value: str) -> exp.Expr: + return exp.maybe_parse(value) or exp.column(value) + + +def _parse_order_oracle(value: str) -> exp.Expr: + parsed = parsing_utils.maybe_parse(str(value), into=exp.Ordered) + if parsed: + return parsed + return _parse_column_oracle(value) + + +def _assert_matches_oracle(value: str, parser, oracle) -> None: + with contextlib.suppress(ParseError): + expected = oracle(value) + actual = parser(value) + assert type(actual) is type(expected) + assert actual.sql() == expected.sql() + return + + with pytest.raises(ParseError): + parser(value) + + +@pytest.mark.parametrize("value", _PARSING_CORPUS) +def test_parse_column_expression_matches_parser_oracle(value: str) -> None: + """Column parser fast paths must preserve the previous parser-backed shape.""" + _assert_matches_oracle(value, parse_column_expression, _parse_column_oracle) + + +@pytest.mark.parametrize("value", _PARSING_CORPUS) +def test_parse_order_expression_matches_parser_oracle(value: str) -> None: + """Order parser fast paths must preserve the previous parser-backed shape.""" + _assert_matches_oracle(value, parse_order_expression, _parse_order_oracle) + + +def test_parse_column_expression_simple_identifier_avoids_parser(monkeypatch: pytest.MonkeyPatch) -> None: + """Simple column identifiers should avoid sqlglot parsing.""" + calls = 0 + original = exp.maybe_parse + + def recorder(*args, **kwargs): # type: ignore[no-untyped-def] + nonlocal calls + calls += 1 + return original(*args, **kwargs) + + monkeypatch.setattr(exp, "maybe_parse", recorder) + + expr = parse_column_expression("users.name") + + assert calls == 0 + assert expr.sql() == "users.name" + + +def test_parse_table_expression_simple_identifier_avoids_parser(monkeypatch: pytest.MonkeyPatch) -> None: + """Simple table identifiers should avoid SELECT-wrapper parsing.""" + calls = 0 + original = exp.maybe_parse + + def recorder(*args, **kwargs): # type: ignore[no-untyped-def] + nonlocal calls + calls += 1 + return original(*args, **kwargs) + + monkeypatch.setattr(exp, "maybe_parse", recorder) + + expr = parse_table_expression("schema.users", explicit_alias="u") + + assert calls == 0 + assert expr.sql() == "schema.users AS u" + + +def test_parse_order_expression_directional_identifier_avoids_parser(monkeypatch: pytest.MonkeyPatch) -> None: + """Directional simple ORDER BY identifiers should avoid sqlglot parsing.""" + calls = 0 + original = parsing_utils.maybe_parse + + def recorder(*args, **kwargs): # type: ignore[no-untyped-def] + nonlocal calls + calls += 1 + return original(*args, **kwargs) + + monkeypatch.setattr(parsing_utils, "maybe_parse", recorder) + + expr = parse_order_expression("users.name asc") + + assert calls == 0 + assert isinstance(expr, exp.Ordered) + assert expr.sql() == "users.name ASC" + + def test_parse_column_expression_sqlglot_passthrough() -> None: """Test that parse_column_expression passes through SQLGlot expressions.""" original_expr = exp.column("test") diff --git a/tests/unit/builder/test_temporal.py b/tests/unit/builder/test_temporal.py index 963e029d9..a09f89580 100644 --- a/tests/unit/builder/test_temporal.py +++ b/tests/unit/builder/test_temporal.py @@ -1,6 +1,10 @@ import re +import subprocess +import sys -from sqlspec.builder import sql +from sqlglot import exp + +from sqlspec.builder import create_temporal_table, sql def normalize_sql(sql_str: str) -> str: @@ -121,3 +125,22 @@ def test_join_as_of_dialect_override() -> None: # AS keyword may or may not be present before alias assert "LEFT JOIN audit_log AS OF TIMESTAMP CAST('2023-01-01' AS TIMESTAMP)" in normalized assert "log ON orders.id = log.order_id" in normalized + + +def test_builder_import_does_not_eagerly_import_temporal_dialects() -> None: + """Importing sqlspec.builder should not load temporal dialect modules.""" + code = ( + "import sys; import sqlspec.builder; " + "loaded = [name for name in ('sqlglot.dialects.snowflake', 'sqlglot.dialects.oracle') if name in sys.modules]; " + "assert not loaded, loaded" + ) + result = subprocess.run([sys.executable, "-c", code], check=False, capture_output=True, text=True) + + assert result.returncode == 0, result.stderr + + +def test_create_temporal_table_registers_version_generators_lazily() -> None: + """Temporal API calls should register exp.Version rendering before SQL generation.""" + table_expr = create_temporal_table("users", exp.Literal.string("2024-01-01")) + + assert "AS OF TIMESTAMP '2024-01-01'" in normalize_sql(table_expr.sql(dialect="oracle")) diff --git a/tests/unit/core/test_statement.py b/tests/unit/core/test_statement.py index ab2d2ca71..ce42941ae 100644 --- a/tests/unit/core/test_statement.py +++ b/tests/unit/core/test_statement.py @@ -22,6 +22,7 @@ import pytest from sqlglot import expressions as exp +from sqlglot.dialects.postgres import Postgres import sqlspec.typing as public_typing from sqlspec.core import ( @@ -450,6 +451,16 @@ def test_sql_initialization_with_custom_config() -> None: assert stmt.statement_config.dialect == "sqlite" +@pytest.mark.parametrize( + ("dialect", "expected"), [("postgres", "postgres"), (Postgres, "postgres"), (Postgres(), "postgres"), (None, None)] +) +def test_sql_normalizes_postgres_dialect_inputs(dialect, expected) -> None: + """StatementConfig.dialect should accept sqlglot dialect names, classes, and instances.""" + stmt = SQL("SELECT 1", statement_config=StatementConfig(dialect=dialect)) + + assert stmt.dialect == expected + + def test_sql_initialization_from_sql_object() -> None: """Test SQL initialization from existing SQL object.""" original = SQL("SELECT * FROM users", id=1)