From e5ac9cfd518a7e3b0ecdcf70d5a2a7e784a412f6 Mon Sep 17 00:00:00 2001 From: aabhasr Date: Tue, 28 Apr 2026 16:41:52 +0530 Subject: [PATCH] feat: add SQL language support to parser MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds .sql file parsing via tree-sitter-sql (already in language pack): - CREATE TABLE/VIEW → Class nodes with extra["sql_kind"] - CREATE FUNCTION → Function nodes with extra["sql_kind"]="function" - CREATE PROCEDURE → regex fallback (grammar emits ERROR for this node) - FROM/JOIN table references → IMPORTS_FROM edges for impact analysis Includes fixture (tests/fixtures/sample.sql) and 8 TestSQLParsing tests. Co-Authored-By: Claude Sonnet 4.6 --- code_review_graph/parser.py | 201 ++++++++++++++++++++++++++++++++++++ tests/fixtures/sample.sql | 37 +++++++ tests/test_multilang.py | 56 ++++++++++ uv.lock | 1 + 4 files changed, 295 insertions(+) create mode 100644 tests/fixtures/sample.sql diff --git a/code_review_graph/parser.py b/code_review_graph/parser.py index f681263a..2d975b06 100644 --- a/code_review_graph/parser.py +++ b/code_review_graph/parser.py @@ -32,6 +32,15 @@ class CellInfo(NamedTuple): re.IGNORECASE, ) +# SQL keywords that can appear after FROM/JOIN but are NOT table names. +_SQL_KEYWORDS: frozenset[str] = frozenset({ + "SELECT", "WHERE", "GROUP", "ORDER", "HAVING", "LIMIT", "OFFSET", + "UNION", "INTERSECT", "EXCEPT", "AS", "ON", "USING", "SET", + "VALUES", "DEFAULT", "NULL", "TRUE", "FALSE", + "INNER", "OUTER", "LEFT", "RIGHT", "FULL", "CROSS", "NATURAL", + "LATERAL", "RECURSIVE", "ONLY", "WITH", +}) + logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -125,6 +134,7 @@ class EdgeInfo: ".res": "rescript", ".resi": "rescript", ".gd": "gdscript", + ".sql": "sql", } # Tree-sitter node type mappings per language @@ -172,6 +182,8 @@ class EdgeInfo: "zig": ["container_declaration"], "powershell": ["class_statement"], "julia": ["struct_definition", "abstract_definition"], + # SQL: CREATE TABLE / CREATE VIEW are handled via _parse_sql dispatch. + "sql": [], } _FUNCTION_TYPES: dict[str, list[str]] = { @@ -222,6 +234,8 @@ class EdgeInfo: "function_definition", "short_function_definition", ], + # SQL: CREATE FUNCTION / CREATE PROCEDURE handled via _parse_sql dispatch. + "sql": [], } _IMPORT_TYPES: dict[str, list[str]] = { @@ -262,6 +276,8 @@ class EdgeInfo: "powershell": [], # Julia: import/using are import_statement nodes. "julia": ["import_statement", "using_statement"], + # SQL: table references extracted as IMPORTS_FROM via _parse_sql dispatch. + "sql": [], } _CALL_TYPES: dict[str, list[str]] = { @@ -300,6 +316,8 @@ class EdgeInfo: "zig": ["call_expression", "builtin_call_expr"], "powershell": ["command_expression"], "julia": ["call_expression"], + # SQL: no call edges extracted (grammar too unreliable for procedure calls). + "sql": [], } # Patterns that indicate a test function @@ -682,6 +700,11 @@ def parse_bytes(self, path: Path, source: bytes) -> tuple[list[NodeInfo], list[E if language == "rescript": return self._parse_rescript(path, source) + # SQL: dedicated parser — tree-sitter for tables/views/functions + + # regex fallback for CREATE PROCEDURE (unsupported by the grammar). + if language == "sql": + return self._parse_sql(path, source) + parser = self._get_parser(language) if not parser: return [], [] @@ -1675,6 +1698,184 @@ def enclosing_module(off: int) -> Optional[str]: return nodes, edges + # ------------------------------------------------------------------ + # SQL parser + # ------------------------------------------------------------------ + + # Regex for CREATE PROCEDURE — tree-sitter SQL grammar emits an ERROR node + # for this statement, so we fall back to a regex scan. + _SQL_PROC_RE = re.compile( + r"CREATE\s+(?:OR\s+REPLACE\s+)?PROCEDURE\s+(\w+(?:\.\w+)*)", + re.IGNORECASE, + ) + + # Named DDL statements supported by tree-sitter-sql. + _SQL_DDL_NODE_TYPES = frozenset({ + "create_table", + "create_view", + "create_function", + }) + + def _parse_sql( + self, path: Path, source: bytes, + ) -> tuple[list[NodeInfo], list[EdgeInfo]]: + """Parse a `.sql` file. + + Extracts: + - Tables (CREATE TABLE) → Class nodes with extra["sql_kind"]="table" + - Views (CREATE VIEW) → Class nodes with extra["sql_kind"]="view" + - Functions (CREATE FUNCTION) → Function nodes with extra["sql_kind"]="function" + - Procedures (CREATE PROCEDURE, regex fallback) → Function nodes with + extra["sql_kind"]="procedure" + + Data dependencies (FROM/JOIN table references) are recorded as + IMPORTS_FROM edges so the impact-radius query can follow them. + """ + text = source.decode("utf-8", errors="replace") + file_path_str = str(path) + test_file = _is_test_file(file_path_str) + + nodes: list[NodeInfo] = [] + edges: list[EdgeInfo] = [] + + nodes.append(NodeInfo( + kind="File", + name=file_path_str, + file_path=file_path_str, + line_start=1, + line_end=text.count("\n") + 1, + language="sql", + is_test=test_file, + )) + + # --- tree-sitter pass --- + parser = self._get_parser("sql") + if parser: + tree = parser.parse(source) + self._walk_sql_tree( + tree.root_node, source, file_path_str, nodes, edges, + ) + + # --- regex fallback for CREATE PROCEDURE --- + for m in self._SQL_PROC_RE.finditer(text): + raw_name = m.group(1) + name = raw_name.split(".")[-1] # strip schema prefix + line = text[: m.start()].count("\n") + 1 + qualified = f"{file_path_str}::{name}" + nodes.append(NodeInfo( + kind="Function", + name=name, + file_path=file_path_str, + line_start=line, + line_end=line, + language="sql", + extra={"sql_kind": "procedure"}, + )) + edges.append(EdgeInfo( + kind="CONTAINS", + source=file_path_str, + target=qualified, + file_path=file_path_str, + line=line, + )) + + # --- table-reference pass (FROM / JOIN targets) --- + seen_refs: set[str] = set() + for m in _SQL_TABLE_RE.finditer(text): + raw_ref = m.group(1).strip("`") + ref = raw_ref.split(".")[-1] # strip schema/db prefix + if ref and ref.upper() not in _SQL_KEYWORDS and ref not in seen_refs: + seen_refs.add(ref) + line = text[: m.start()].count("\n") + 1 + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path_str, + target=ref, + file_path=file_path_str, + line=line, + )) + + return nodes, edges + + def _walk_sql_tree( + self, + node, + source: bytes, + file_path_str: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + ) -> None: + """Recursively walk a tree-sitter SQL AST and extract DDL entities.""" + if node.type in self._SQL_DDL_NODE_TYPES: + self._extract_sql_ddl(node, source, file_path_str, nodes, edges) + return # don't recurse into the DDL body — no nested DDL expected + for child in node.children: + self._walk_sql_tree(child, source, file_path_str, nodes, edges) + + def _extract_sql_ddl( + self, + node, + source: bytes, + file_path_str: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + ) -> None: + """Extract a single CREATE TABLE / VIEW / FUNCTION DDL node.""" + node_type = node.type + line_start = node.start_point[0] + 1 + line_end = node.end_point[0] + 1 + + # Locate the identifier / object_reference child that holds the name. + name: Optional[str] = None + for child in node.children: + if child.type in ("identifier", "object_reference", "dotted_name"): + raw = source[child.start_byte: child.end_byte].decode("utf-8", errors="replace") + # Strip schema prefix (schema.name → name) + name = raw.strip("`\"").split(".")[-1] + break + # Some grammars nest: relation > object_reference > identifier + if child.type == "relation": + for gc in child.children: + if gc.type in ("object_reference", "identifier"): + raw = source[gc.start_byte: gc.end_byte].decode( + "utf-8", errors="replace", + ) + name = raw.strip("`\"").split(".")[-1] + break + if name: + break + + if not name: + return + + if node_type == "create_table": + kind = "Class" + sql_kind = "table" + elif node_type == "create_view": + kind = "Class" + sql_kind = "view" + else: # create_function + kind = "Function" + sql_kind = "function" + + qualified = f"{file_path_str}::{name}" + nodes.append(NodeInfo( + kind=kind, + name=name, + file_path=file_path_str, + line_start=line_start, + line_end=line_end, + language="sql", + extra={"sql_kind": sql_kind}, + )) + edges.append(EdgeInfo( + kind="CONTAINS", + source=file_path_str, + target=qualified, + file_path=file_path_str, + line=line_start, + )) + def _resolve_call_targets( self, nodes: list[NodeInfo], diff --git a/tests/fixtures/sample.sql b/tests/fixtures/sample.sql new file mode 100644 index 00000000..4dbf41a2 --- /dev/null +++ b/tests/fixtures/sample.sql @@ -0,0 +1,37 @@ +-- Sample SQL fixture for code-review-graph parser tests + +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE +); + +CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER REFERENCES users(id), + total NUMERIC(10, 2), + created_at TIMESTAMP DEFAULT NOW() +); + +CREATE VIEW active_orders AS + SELECT o.id, u.name, o.total + FROM orders o + JOIN users u ON u.id = o.user_id + WHERE o.total > 0; + +CREATE FUNCTION get_user_total(p_user_id INTEGER) +RETURNS NUMERIC AS $$ + SELECT SUM(total) + FROM orders + WHERE user_id = p_user_id; +$$ LANGUAGE sql; + +CREATE OR REPLACE PROCEDURE archive_old_orders(cutoff_date DATE) +LANGUAGE plpgsql AS $$ +BEGIN + INSERT INTO orders_archive + SELECT * FROM orders WHERE created_at < cutoff_date; + + DELETE FROM orders WHERE created_at < cutoff_date; +END; +$$; diff --git a/tests/test_multilang.py b/tests/test_multilang.py index 9d45f434..1988f6b9 100644 --- a/tests/test_multilang.py +++ b/tests/test_multilang.py @@ -1916,3 +1916,59 @@ def test_resolver_is_idempotent(self, tmp_path): # Second run should find nothing new — all already resolved. assert second["calls_resolved"] == 0 assert second["imports_resolved"] == 0 + + +class TestSQLParsing: + def setup_method(self): + self.parser = CodeParser() + self.nodes, self.edges = self.parser.parse_file(FIXTURES / "sample.sql") + + def test_detects_language(self): + assert self.parser.detect_language(Path("schema.sql")) == "sql" + + def test_file_node(self): + file_nodes = [n for n in self.nodes if n.kind == "File"] + assert len(file_nodes) == 1 + assert file_nodes[0].language == "sql" + + def test_finds_tables(self): + tables = [n for n in self.nodes if n.kind == "Class" and n.extra.get("sql_kind") == "table"] + names = {t.name for t in tables} + assert "users" in names + assert "orders" in names + + def test_finds_view(self): + views = [n for n in self.nodes if n.kind == "Class" and n.extra.get("sql_kind") == "view"] + names = {v.name for v in views} + assert "active_orders" in names + + def test_finds_function(self): + funcs = [ + n for n in self.nodes + if n.kind == "Function" and n.extra.get("sql_kind") == "function" + ] + names = {f.name for f in funcs} + assert "get_user_total" in names + + def test_finds_procedure(self): + procs = [ + n for n in self.nodes + if n.kind == "Function" and n.extra.get("sql_kind") == "procedure" + ] + names = {p.name for p in procs} + assert "archive_old_orders" in names + + def test_contains_edges(self): + contains = [e for e in self.edges if e.kind == "CONTAINS"] + targets = {e.target.split("::")[-1] for e in contains} + assert "users" in targets + assert "orders" in targets + assert "active_orders" in targets + assert "get_user_total" in targets + assert "archive_old_orders" in targets + + def test_table_reference_edges(self): + imports = [e for e in self.edges if e.kind == "IMPORTS_FROM"] + targets = {e.target for e in imports} + # active_orders view and archive procedure both reference orders/users + assert "orders" in targets or "users" in targets diff --git a/uv.lock b/uv.lock index 62a32add..c40535b3 100644 --- a/uv.lock +++ b/uv.lock @@ -411,6 +411,7 @@ requires-dist = [ { name = "pyyaml", marker = "extra == 'eval'", specifier = ">=6.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.3.0,<1" }, { name = "sentence-transformers", marker = "extra == 'embeddings'", specifier = ">=3.0.0,<4" }, + { name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.0,<3" }, { name = "tomli", marker = "python_full_version < '3.11' and extra == 'dev'", specifier = ">=2.0" }, { name = "tree-sitter", specifier = ">=0.23.0,<1" }, { name = "tree-sitter-language-pack", specifier = ">=0.3.0,<1" },