Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions code_review_graph/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ class CellInfo(NamedTuple):
re.IGNORECASE,
)

# SQL keywords that can appear after FROM/JOIN but are NOT table names.
_SQL_KEYWORDS: frozenset[str] = frozenset({
"SELECT", "WHERE", "GROUP", "ORDER", "HAVING", "LIMIT", "OFFSET",
"UNION", "INTERSECT", "EXCEPT", "AS", "ON", "USING", "SET",
"VALUES", "DEFAULT", "NULL", "TRUE", "FALSE",
"INNER", "OUTER", "LEFT", "RIGHT", "FULL", "CROSS", "NATURAL",
"LATERAL", "RECURSIVE", "ONLY", "WITH",
})

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -125,6 +134,7 @@ class EdgeInfo:
".res": "rescript",
".resi": "rescript",
".gd": "gdscript",
".sql": "sql",
}

# Tree-sitter node type mappings per language
Expand Down Expand Up @@ -172,6 +182,8 @@ class EdgeInfo:
"zig": ["container_declaration"],
"powershell": ["class_statement"],
"julia": ["struct_definition", "abstract_definition"],
# SQL: CREATE TABLE / CREATE VIEW are handled via _parse_sql dispatch.
"sql": [],
}

_FUNCTION_TYPES: dict[str, list[str]] = {
Expand Down Expand Up @@ -222,6 +234,8 @@ class EdgeInfo:
"function_definition",
"short_function_definition",
],
# SQL: CREATE FUNCTION / CREATE PROCEDURE handled via _parse_sql dispatch.
"sql": [],
}

_IMPORT_TYPES: dict[str, list[str]] = {
Expand Down Expand Up @@ -262,6 +276,8 @@ class EdgeInfo:
"powershell": [],
# Julia: import/using are import_statement nodes.
"julia": ["import_statement", "using_statement"],
# SQL: table references extracted as IMPORTS_FROM via _parse_sql dispatch.
"sql": [],
}

_CALL_TYPES: dict[str, list[str]] = {
Expand Down Expand Up @@ -300,6 +316,8 @@ class EdgeInfo:
"zig": ["call_expression", "builtin_call_expr"],
"powershell": ["command_expression"],
"julia": ["call_expression"],
# SQL: no call edges extracted (grammar too unreliable for procedure calls).
"sql": [],
}

# Patterns that indicate a test function
Expand Down Expand Up @@ -682,6 +700,11 @@ def parse_bytes(self, path: Path, source: bytes) -> tuple[list[NodeInfo], list[E
if language == "rescript":
return self._parse_rescript(path, source)

# SQL: dedicated parser — tree-sitter for tables/views/functions +
# regex fallback for CREATE PROCEDURE (unsupported by the grammar).
if language == "sql":
return self._parse_sql(path, source)

parser = self._get_parser(language)
if not parser:
return [], []
Expand Down Expand Up @@ -1675,6 +1698,184 @@ def enclosing_module(off: int) -> Optional[str]:

return nodes, edges

# ------------------------------------------------------------------
# SQL parser
# ------------------------------------------------------------------

# Regex for CREATE PROCEDURE — tree-sitter SQL grammar emits an ERROR node
# for this statement, so we fall back to a regex scan.
_SQL_PROC_RE = re.compile(
r"CREATE\s+(?:OR\s+REPLACE\s+)?PROCEDURE\s+(\w+(?:\.\w+)*)",
re.IGNORECASE,
)

# Named DDL statements supported by tree-sitter-sql.
_SQL_DDL_NODE_TYPES = frozenset({
"create_table",
"create_view",
"create_function",
})

def _parse_sql(
self, path: Path, source: bytes,
) -> tuple[list[NodeInfo], list[EdgeInfo]]:
"""Parse a `.sql` file.

Extracts:
- Tables (CREATE TABLE) → Class nodes with extra["sql_kind"]="table"
- Views (CREATE VIEW) → Class nodes with extra["sql_kind"]="view"
- Functions (CREATE FUNCTION) → Function nodes with extra["sql_kind"]="function"
- Procedures (CREATE PROCEDURE, regex fallback) → Function nodes with
extra["sql_kind"]="procedure"

Data dependencies (FROM/JOIN table references) are recorded as
IMPORTS_FROM edges so the impact-radius query can follow them.
"""
text = source.decode("utf-8", errors="replace")
file_path_str = str(path)
test_file = _is_test_file(file_path_str)

nodes: list[NodeInfo] = []
edges: list[EdgeInfo] = []

nodes.append(NodeInfo(
kind="File",
name=file_path_str,
file_path=file_path_str,
line_start=1,
line_end=text.count("\n") + 1,
language="sql",
is_test=test_file,
))

# --- tree-sitter pass ---
parser = self._get_parser("sql")
if parser:
tree = parser.parse(source)
self._walk_sql_tree(
tree.root_node, source, file_path_str, nodes, edges,
)

# --- regex fallback for CREATE PROCEDURE ---
for m in self._SQL_PROC_RE.finditer(text):
raw_name = m.group(1)
name = raw_name.split(".")[-1] # strip schema prefix
line = text[: m.start()].count("\n") + 1
qualified = f"{file_path_str}::{name}"
nodes.append(NodeInfo(
kind="Function",
name=name,
file_path=file_path_str,
line_start=line,
line_end=line,
language="sql",
extra={"sql_kind": "procedure"},
))
edges.append(EdgeInfo(
kind="CONTAINS",
source=file_path_str,
target=qualified,
file_path=file_path_str,
line=line,
))

# --- table-reference pass (FROM / JOIN targets) ---
seen_refs: set[str] = set()
for m in _SQL_TABLE_RE.finditer(text):
raw_ref = m.group(1).strip("`")
ref = raw_ref.split(".")[-1] # strip schema/db prefix
if ref and ref.upper() not in _SQL_KEYWORDS and ref not in seen_refs:
seen_refs.add(ref)
line = text[: m.start()].count("\n") + 1
edges.append(EdgeInfo(
kind="IMPORTS_FROM",
source=file_path_str,
target=ref,
file_path=file_path_str,
line=line,
))

return nodes, edges

def _walk_sql_tree(
self,
node,
source: bytes,
file_path_str: str,
nodes: list[NodeInfo],
edges: list[EdgeInfo],
) -> None:
"""Recursively walk a tree-sitter SQL AST and extract DDL entities."""
if node.type in self._SQL_DDL_NODE_TYPES:
self._extract_sql_ddl(node, source, file_path_str, nodes, edges)
return # don't recurse into the DDL body — no nested DDL expected
for child in node.children:
self._walk_sql_tree(child, source, file_path_str, nodes, edges)

def _extract_sql_ddl(
self,
node,
source: bytes,
file_path_str: str,
nodes: list[NodeInfo],
edges: list[EdgeInfo],
) -> None:
"""Extract a single CREATE TABLE / VIEW / FUNCTION DDL node."""
node_type = node.type
line_start = node.start_point[0] + 1
line_end = node.end_point[0] + 1

# Locate the identifier / object_reference child that holds the name.
name: Optional[str] = None
for child in node.children:
if child.type in ("identifier", "object_reference", "dotted_name"):
raw = source[child.start_byte: child.end_byte].decode("utf-8", errors="replace")
# Strip schema prefix (schema.name → name)
name = raw.strip("`\"").split(".")[-1]
break
# Some grammars nest: relation > object_reference > identifier
if child.type == "relation":
for gc in child.children:
if gc.type in ("object_reference", "identifier"):
raw = source[gc.start_byte: gc.end_byte].decode(
"utf-8", errors="replace",
)
name = raw.strip("`\"").split(".")[-1]
break
if name:
break

if not name:
return

if node_type == "create_table":
kind = "Class"
sql_kind = "table"
elif node_type == "create_view":
kind = "Class"
sql_kind = "view"
else: # create_function
kind = "Function"
sql_kind = "function"

qualified = f"{file_path_str}::{name}"
nodes.append(NodeInfo(
kind=kind,
name=name,
file_path=file_path_str,
line_start=line_start,
line_end=line_end,
language="sql",
extra={"sql_kind": sql_kind},
))
edges.append(EdgeInfo(
kind="CONTAINS",
source=file_path_str,
target=qualified,
file_path=file_path_str,
line=line_start,
))

def _resolve_call_targets(
self,
nodes: list[NodeInfo],
Expand Down
37 changes: 37 additions & 0 deletions tests/fixtures/sample.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
-- Sample SQL fixture for code-review-graph parser tests

CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT UNIQUE
);

CREATE TABLE orders (
id INTEGER PRIMARY KEY,
user_id INTEGER REFERENCES users(id),
total NUMERIC(10, 2),
created_at TIMESTAMP DEFAULT NOW()
);

CREATE VIEW active_orders AS
SELECT o.id, u.name, o.total
FROM orders o
JOIN users u ON u.id = o.user_id
WHERE o.total > 0;

CREATE FUNCTION get_user_total(p_user_id INTEGER)
RETURNS NUMERIC AS $$
SELECT SUM(total)
FROM orders
WHERE user_id = p_user_id;
$$ LANGUAGE sql;

CREATE OR REPLACE PROCEDURE archive_old_orders(cutoff_date DATE)
LANGUAGE plpgsql AS $$
BEGIN
INSERT INTO orders_archive
SELECT * FROM orders WHERE created_at < cutoff_date;

DELETE FROM orders WHERE created_at < cutoff_date;
END;
$$;
56 changes: 56 additions & 0 deletions tests/test_multilang.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,3 +1916,59 @@ def test_resolver_is_idempotent(self, tmp_path):
# Second run should find nothing new — all already resolved.
assert second["calls_resolved"] == 0
assert second["imports_resolved"] == 0


class TestSQLParsing:
def setup_method(self):
self.parser = CodeParser()
self.nodes, self.edges = self.parser.parse_file(FIXTURES / "sample.sql")

def test_detects_language(self):
assert self.parser.detect_language(Path("schema.sql")) == "sql"

def test_file_node(self):
file_nodes = [n for n in self.nodes if n.kind == "File"]
assert len(file_nodes) == 1
assert file_nodes[0].language == "sql"

def test_finds_tables(self):
tables = [n for n in self.nodes if n.kind == "Class" and n.extra.get("sql_kind") == "table"]
names = {t.name for t in tables}
assert "users" in names
assert "orders" in names

def test_finds_view(self):
views = [n for n in self.nodes if n.kind == "Class" and n.extra.get("sql_kind") == "view"]
names = {v.name for v in views}
assert "active_orders" in names

def test_finds_function(self):
funcs = [
n for n in self.nodes
if n.kind == "Function" and n.extra.get("sql_kind") == "function"
]
names = {f.name for f in funcs}
assert "get_user_total" in names

def test_finds_procedure(self):
procs = [
n for n in self.nodes
if n.kind == "Function" and n.extra.get("sql_kind") == "procedure"
]
names = {p.name for p in procs}
assert "archive_old_orders" in names

def test_contains_edges(self):
contains = [e for e in self.edges if e.kind == "CONTAINS"]
targets = {e.target.split("::")[-1] for e in contains}
assert "users" in targets
assert "orders" in targets
assert "active_orders" in targets
assert "get_user_total" in targets
assert "archive_old_orders" in targets

def test_table_reference_edges(self):
imports = [e for e in self.edges if e.kind == "IMPORTS_FROM"]
targets = {e.target for e in imports}
# active_orders view and archive procedure both reference orders/users
assert "orders" in targets or "users" in targets
1 change: 1 addition & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.