Skip to content

Commit 7ed5db6

Browse files
committed
Use parameterized queries in HiveStatsCollectionOperator
HiveStatsCollectionOperator builds its SQL (the bookkeeping SELECT and DELETE against the hive_stats table in the MySQL database reached through mysql_conn_id, and the stats query SELECT ... FROM <table> WHERE <partition_key> = '<value>' against Presto) by f-string-interpolating template-rendered fields (table, partition, dttm) directly into raw SQL strings. Per the security model in airflow-core/docs/security/security_model.rst and airflow-core/docs/security/sql.rst, this is not a vulnerability -- Dag authors are trusted users responsible for sanitizing input before passing it to operators. The change here is defense-in-depth so that the operator does not rely on each Dag author to sanitize those values. The MySQL bookkeeping SELECT and DELETE now use %s placeholders with the parameters= kwarg of MySqlHook.get_records and MySqlHook.run, so the values are bound by the driver instead of interpolated. For the Presto SELECT, the table name and partition column names cannot be parameterized in standard SQL, so they are still interpolated but are now validated first: the table against ^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$ and each partition key against ^[A-Za-z_][A-Za-z0-9_]*$, raising AirflowException on mismatch. Partition values are passed as bound parameters via the parameters= kwarg of PrestoHook.get_first.
1 parent 354391b commit 7ed5db6

2 files changed

Lines changed: 131 additions & 26 deletions

File tree

providers/apache/hive/src/airflow/providers/apache/hive/operators/hive_stats.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import annotations
1919

2020
import json
21+
import re
2122
from collections.abc import Callable, Sequence
2223
from typing import TYPE_CHECKING, Any
2324

@@ -29,6 +30,13 @@
2930
if TYPE_CHECKING:
3031
from airflow.providers.common.compat.sdk import Context
3132

33+
# Hive table names may be qualified as `<database>.<table>`; identifiers must
34+
# be plain word characters so they can be safely interpolated into the Presto
35+
# query that selects partition stats. Identifiers cannot be bound as parameters
36+
# in standard SQL.
37+
_HIVE_TABLE_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")
38+
_HIVE_COLUMN_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
39+
3240

3341
class HiveStatsCollectionOperator(BaseOperator):
3442
"""
@@ -112,6 +120,16 @@ def get_default_exprs(self, col: str, col_type: str) -> dict[Any, Any]:
112120
return exp
113121

114122
def execute(self, context: Context) -> None:
123+
if not _HIVE_TABLE_RE.match(self.table):
124+
raise AirflowException(
125+
f"Invalid Hive table identifier: {self.table!r}. Must match {_HIVE_TABLE_RE.pattern}."
126+
)
127+
for partition_key in self.partition.keys():
128+
if not _HIVE_COLUMN_RE.match(partition_key):
129+
raise AirflowException(
130+
f"Invalid partition column name: {partition_key!r}. Must match {_HIVE_COLUMN_RE.pattern}."
131+
)
132+
115133
metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
116134
table = metastore.get_table(table_name=self.table)
117135
field_types = {col.name: col.type for col in table.sd.cols}
@@ -128,13 +146,13 @@ def execute(self, context: Context) -> None:
128146
exprs.update(self.extra_exprs)
129147
exprs_str = ",\n ".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items())
130148

131-
where_clause_ = [f"{k} = '{v}'" for k, v in self.partition.items()]
149+
where_clause_ = [f"{k} = %s" for k in self.partition.keys()]
132150
where_clause = " AND\n ".join(where_clause_)
133151
sql = f"SELECT {exprs_str} FROM {self.table} WHERE {where_clause};"
134152

135153
presto = PrestoHook(presto_conn_id=self.presto_conn_id)
136154
self.log.info("Executing SQL check: %s", sql)
137-
row = presto.get_first(sql)
155+
row = presto.get_first(sql, parameters=tuple(self.partition.values()))
138156
self.log.info("Record: %s", row)
139157
if not row:
140158
raise AirflowException("The query returned None")
@@ -143,23 +161,23 @@ def execute(self, context: Context) -> None:
143161

144162
self.log.info("Deleting rows from previous runs if they exist")
145163
mysql = MySqlHook(self.mysql_conn_id)
146-
sql = f"""
164+
sql = """
147165
SELECT 1 FROM hive_stats
148166
WHERE
149-
table_name='{self.table}' AND
150-
partition_repr='{part_json}' AND
151-
dttm='{self.dttm}'
167+
table_name = %s AND
168+
partition_repr = %s AND
169+
dttm = %s
152170
LIMIT 1;
153171
"""
154-
if mysql.get_records(sql):
155-
sql = f"""
172+
if mysql.get_records(sql, parameters=(self.table, part_json, self.dttm)):
173+
sql = """
156174
DELETE FROM hive_stats
157175
WHERE
158-
table_name='{self.table}' AND
159-
partition_repr='{part_json}' AND
160-
dttm='{self.dttm}';
176+
table_name = %s AND
177+
partition_repr = %s AND
178+
dttm = %s;
161179
"""
162-
mysql.run(sql)
180+
mysql.run(sql, parameters=(self.table, part_json, self.dttm))
163181

164182
self.log.info("Pivoting and loading cells into the Airflow db")
165183
rows = [

providers/apache/hive/tests/unit/apache/hive/operators/test_hive_stats.py

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -292,14 +292,97 @@ def test_execute_delete_previous_runs_rows(
292292
hive_stats_collection_operator = HiveStatsCollectionOperator(**self.kwargs)
293293
hive_stats_collection_operator.execute(context={})
294294

295-
sql = f"""
295+
expected_sql = """
296296
DELETE FROM hive_stats
297297
WHERE
298-
table_name='{hive_stats_collection_operator.table}' AND
299-
partition_repr='{mock_json_dumps.return_value}' AND
300-
dttm='{hive_stats_collection_operator.dttm}';
298+
table_name = %s AND
299+
partition_repr = %s AND
300+
dttm = %s;
301301
"""
302-
mock_mysql_hook.return_value.run.assert_called_once_with(sql)
302+
mock_mysql_hook.return_value.run.assert_called_once_with(
303+
expected_sql,
304+
parameters=(
305+
hive_stats_collection_operator.table,
306+
mock_json_dumps.return_value,
307+
hive_stats_collection_operator.dttm,
308+
),
309+
)
310+
311+
@patch("airflow.providers.apache.hive.operators.hive_stats.MySqlHook")
312+
@patch("airflow.providers.apache.hive.operators.hive_stats.PrestoHook")
313+
@patch("airflow.providers.apache.hive.operators.hive_stats.HiveMetastoreHook")
314+
def test_execute_rejects_invalid_table_identifier(
315+
self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook
316+
):
317+
# The Presto SELECT interpolates the table identifier; the operator
318+
# rejects any value that does not match the <db>.<table> allowlist
319+
# so callers cannot smuggle whitespace or punctuation into the
320+
# identifier position.
321+
self.kwargs["table"] = "evil; DROP TABLE users--"
322+
with pytest.raises(AirflowException, match="Invalid Hive table identifier"):
323+
HiveStatsCollectionOperator(**self.kwargs).execute(context={})
324+
325+
@patch("airflow.providers.apache.hive.operators.hive_stats.MySqlHook")
326+
@patch("airflow.providers.apache.hive.operators.hive_stats.PrestoHook")
327+
@patch("airflow.providers.apache.hive.operators.hive_stats.HiveMetastoreHook")
328+
def test_execute_rejects_invalid_partition_column(
329+
self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook
330+
):
331+
# Partition keys reach the WHERE clause of the Presto SELECT as column
332+
# identifiers and are validated against the unqualified-column allowlist.
333+
self.kwargs["partition"] = {"evil col": "value"}
334+
with pytest.raises(AirflowException, match="Invalid partition column name"):
335+
HiveStatsCollectionOperator(**self.kwargs).execute(context={})
336+
337+
@patch("airflow.providers.apache.hive.operators.hive_stats.json.dumps")
338+
@patch("airflow.providers.apache.hive.operators.hive_stats.MySqlHook")
339+
@patch("airflow.providers.apache.hive.operators.hive_stats.PrestoHook")
340+
@patch("airflow.providers.apache.hive.operators.hive_stats.HiveMetastoreHook")
341+
def test_execute_parameterizes_mysql_bookkeeping_queries(
342+
self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps
343+
):
344+
# The bookkeeping SELECT and DELETE against hive_stats bind table,
345+
# partition_repr, and dttm as %s parameters instead of interpolating
346+
# them into the SQL body, so the operator does not rely on the
347+
# caller to escape those values.
348+
mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col]
349+
mock_mysql_hook.return_value.get_records.return_value = True
350+
351+
op = HiveStatsCollectionOperator(**self.kwargs)
352+
op.execute(context={})
353+
354+
select_call = mock_mysql_hook.return_value.get_records.call_args
355+
delete_call = mock_mysql_hook.return_value.run.call_args
356+
357+
select_sql = select_call.args[0]
358+
delete_sql = delete_call.args[0]
359+
assert "%s" in select_sql
360+
assert "%s" in delete_sql
361+
assert op.table not in select_sql
362+
assert op.table not in delete_sql
363+
364+
expected_params = (op.table, mock_json_dumps.return_value, op.dttm)
365+
assert select_call.kwargs["parameters"] == expected_params
366+
assert delete_call.kwargs["parameters"] == expected_params
367+
368+
@patch("airflow.providers.apache.hive.operators.hive_stats.MySqlHook")
369+
@patch("airflow.providers.apache.hive.operators.hive_stats.PrestoHook")
370+
@patch("airflow.providers.apache.hive.operators.hive_stats.HiveMetastoreHook")
371+
def test_execute_parameterizes_presto_partition_values(
372+
self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook
373+
):
374+
# Partition values cannot influence the Presto SQL body — they are
375+
# passed as bound parameters alongside the SELECT.
376+
mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col]
377+
mock_mysql_hook.return_value.get_records.return_value = False
378+
379+
self.kwargs["partition"] = {"col": "value"}
380+
HiveStatsCollectionOperator(**self.kwargs).execute(context={})
381+
382+
presto_call = mock_presto_hook.return_value.get_first.call_args
383+
assert "col = %s" in presto_call.args[0]
384+
assert "'value'" not in presto_call.args[0]
385+
assert presto_call.kwargs["parameters"] == ("value",)
303386

304387
@pytest.mark.skipif(
305388
"AIRFLOW_RUNALL_TESTS" not in os.environ, reason="Skipped because AIRFLOW_RUNALL_TESTS is not set"
@@ -326,23 +409,27 @@ def test_runs_for_hive_stats(self, mock_hive_metastore_hook):
326409
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
327410

328411
select_count_query = (
329-
"SELECT COUNT(*) AS __count FROM airflow.static_babynames_partitioned WHERE ds = '2015-01-01';"
412+
"SELECT COUNT(*) AS __count FROM airflow.static_babynames_partitioned WHERE ds = %s;"
330413
)
331-
mock_presto_hook.get_first.assert_called_with(hql=select_count_query)
414+
presto_call = mock_presto_hook.get_first.call_args
415+
actual_presto_query = re.sub(r"\s{2,}", " ", presto_call.args[0]).strip()
416+
assert actual_presto_query == select_count_query
417+
assert presto_call.kwargs["parameters"] == ("2015-01-01",)
332418

333419
expected_stats_select_query = (
334-
"SELECT 1 "
335-
"FROM hive_stats "
336-
"WHERE table_name='airflow.static_babynames_partitioned' "
337-
' AND partition_repr=\'{"ds": "2015-01-01"}\' '
338-
" AND dttm='2015-01-01T00:00:00+00:00' "
339-
"LIMIT 1;"
420+
"SELECT 1 FROM hive_stats WHERE table_name = %s AND partition_repr = %s AND dttm = %s LIMIT 1;"
340421
)
341422

342-
raw_stats_select_query = mock_mysql_hook.get_records.call_args_list[0][0][0]
423+
stats_select_call = mock_mysql_hook.get_records.call_args_list[0]
424+
raw_stats_select_query = stats_select_call[0][0]
343425
actual_stats_select_query = re.sub(r"\s{2,}", " ", raw_stats_select_query).strip()
344426

345427
assert expected_stats_select_query == actual_stats_select_query
428+
assert stats_select_call.kwargs["parameters"] == (
429+
"airflow.static_babynames_partitioned",
430+
'{"ds": "2015-01-01"}',
431+
"2015-01-01T00:00:00+00:00",
432+
)
346433

347434
insert_rows_val = [
348435
(

0 commit comments

Comments
 (0)