diff --git a/tests/unit/test_executemany.py b/tests/unit/test_executemany.py new file mode 100644 index 00000000..217b6aa7 --- /dev/null +++ b/tests/unit/test_executemany.py @@ -0,0 +1,191 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from decimal import Decimal + +import pytest + +import trino.client +from trino.dbapi import _INSERT_VALUES_RE +from trino.dbapi import Connection +from trino.dbapi import Cursor + + +class FakeTrinoQuery: + """Hand-written fake replacing TrinoQuery for unit tests. + + Records every SQL string passed and behaves like a successful INSERT. + """ + + instances = [] + + def __init__(self, request, query, legacy_primitive_types=False, fetch_mode="mapped"): + self._query = query + self._update_type = "INSERT" + FakeTrinoQuery.instances.append(self) + + def execute(self): + return iter([]) + + @property + def query(self): + return self._query + + @property + def update_type(self): + return self._update_type + + +class FakeTrinoRequest: + pass + + +@pytest.fixture +def cursor(): + """Create a real Cursor wired to FakeTrinoQuery. + + Monkeypatches trino.client.TrinoQuery so the production + _executemany_batch_insert runs its real code path but creates + fakes instead of making HTTP calls. + """ + FakeTrinoQuery.instances = [] + original = trino.client.TrinoQuery + trino.client.TrinoQuery = FakeTrinoQuery + + cur = Cursor.__new__(Cursor) + cur._connection = Connection.__new__(Connection) + cur._connection._client_session = type("cs", (), {"timezone": None})() + cur._request = FakeTrinoRequest() + cur._iterator = None + cur._query = None + cur._legacy_primitive_types = False + + yield cur + + trino.client.TrinoQuery = original + + +class TestInsertValuesPattern: + def test_simple_insert(self): + assert _INSERT_VALUES_RE.match("INSERT INTO t (a, b) VALUES (?, ?)") is not None + + def test_insert_with_schema(self): + sql = 'INSERT INTO "my_schema"."my_table" (col1, col2) VALUES (?, ?)' + assert _INSERT_VALUES_RE.match(sql) is not None + + def test_insert_with_catalog_schema(self): + sql = 'INSERT INTO "catalog"."schema"."table" (a, b, c) VALUES (?, ?, ?)' + assert _INSERT_VALUES_RE.match(sql) is not None + + def test_multiline_insert(self): + sql = ' INSERT INTO "schema"."table" (col1, col2)\n VALUES (?, ?)\n ' + assert _INSERT_VALUES_RE.match(sql.strip()) is not None + + def test_insert_no_columns(self): + assert _INSERT_VALUES_RE.match("INSERT INTO t VALUES (?, ?)") is not None + + def test_select_not_matched(self): + assert _INSERT_VALUES_RE.match("SELECT * FROM t WHERE a = ?") is None + + def test_update_not_matched(self): + assert _INSERT_VALUES_RE.match("UPDATE t SET a = ? WHERE b = ?") is None + + def test_insert_select_not_matched(self): + assert _INSERT_VALUES_RE.match("INSERT INTO t SELECT * FROM s") is None + + def test_trailing_semicolon_not_matched(self): + assert _INSERT_VALUES_RE.match("INSERT INTO t (a) VALUES (?);") is None + + def test_case_insensitive(self): + assert _INSERT_VALUES_RE.match("insert into t values (?)") is not None + + def test_prefix_extraction(self): + sql = 'INSERT INTO "s"."t" (a, b) VALUES (?, ?)' + m = _INSERT_VALUES_RE.match(sql) + assert m is not None + assert m.group(1).strip().endswith("VALUES") + + +class TestExecutemanyBatchInsert: + def test_single_row(self, cursor): + cursor.executemany( + "INSERT INTO t (a, b) VALUES (?, ?)", + [(1, "hello")] + ) + assert len(FakeTrinoQuery.instances) == 1 + assert "VALUES (1, 'hello')" in FakeTrinoQuery.instances[0].query + + def test_multiple_rows_single_batch(self, cursor): + cursor.executemany( + "INSERT INTO t (a, b) VALUES (?, ?)", + [(1, "a"), (2, "b"), (3, "c")] + ) + assert len(FakeTrinoQuery.instances) == 1 + sql = FakeTrinoQuery.instances[0].query + assert "(1, 'a')" in sql + assert "(2, 'b')" in sql + assert "(3, 'c')" in sql + + def test_chunking(self, cursor): + import trino.dbapi as _dbapi + original = _dbapi._EXECUTEMANY_BATCH_SIZE + _dbapi._EXECUTEMANY_BATCH_SIZE = 2 + try: + cursor.executemany( + "INSERT INTO t (a) VALUES (?)", + [(1,), (2,), (3,), (4,), (5,)] + ) + assert len(FakeTrinoQuery.instances) == 3 + assert "(1)" in FakeTrinoQuery.instances[0].query + assert "(2)" in FakeTrinoQuery.instances[0].query + assert "(3)" in FakeTrinoQuery.instances[1].query + assert "(4)" in FakeTrinoQuery.instances[1].query + assert "(5)" in FakeTrinoQuery.instances[2].query + finally: + _dbapi._EXECUTEMANY_BATCH_SIZE = original + + def test_null_values(self, cursor): + cursor.executemany( + "INSERT INTO t (a, b) VALUES (?, ?)", + [(1, None), (None, "test")] + ) + sql = FakeTrinoQuery.instances[0].query + assert "(1, NULL)" in sql + assert "(NULL, 'test')" in sql + + def test_mixed_types(self, cursor): + cursor.executemany( + "INSERT INTO t (a, b, c, d) VALUES (?, ?, ?, ?)", + [(42, "text", True, Decimal("3.14"))] + ) + sql = FakeTrinoQuery.instances[0].query + assert "42" in sql + assert "'text'" in sql + assert "true" in sql + assert "DECIMAL '3.14'" in sql + + def test_string_escaping(self, cursor): + cursor.executemany( + "INSERT INTO t (a) VALUES (?)", + [("it's a test",)] + ) + assert "it''s a test" in FakeTrinoQuery.instances[0].query + + def test_empty_params_does_not_batch(self, cursor): + cursor.executemany( + "INSERT INTO t (a) VALUES (?)", + [] + ) + # Empty params takes the execute() path, not batch path. + # The FakeTrinoQuery from execute() still gets created but + # via the non-batch code path. + assert cursor._query is not None + assert cursor._query.query == "INSERT INTO t (a) VALUES (?)" diff --git a/trino/dbapi.py b/trino/dbapi.py index 42eeb547..bc39bbb7 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -19,6 +19,7 @@ """ import datetime import math +import re import uuid from collections import OrderedDict from decimal import Decimal @@ -83,6 +84,13 @@ logger = trino.logging.get_logger(__name__) +_INSERT_VALUES_RE = re.compile( + r"\A(\s*INSERT\s+INTO\s+.+\bVALUES)\s*\([^)]+\)\s*\Z", + re.IGNORECASE | re.DOTALL, +) + +_EXECUTEMANY_BATCH_SIZE = 100 + class TimeBoundLRUCache: """A bounded LRU cache which expires entries after a configured number of seconds. @@ -656,15 +664,38 @@ def executemany(self, operation, seq_of_params): Return values are not defined. """ + if not seq_of_params: + self.execute(operation) + return self + + match = _INSERT_VALUES_RE.match(operation.strip()) + if match: + return self._executemany_batch_insert(match.group(1), seq_of_params) + for parameters in seq_of_params[:-1]: self.execute(operation, parameters) self.fetchall() if self._query.update_type is None: raise NotSupportedError("Query must return update type") - if seq_of_params: - self.execute(operation, seq_of_params[-1]) - else: - self.execute(operation) + self.execute(operation, seq_of_params[-1]) + return self + + def _executemany_batch_insert(self, prefix, seq_of_params): + for i in range(0, len(seq_of_params), _EXECUTEMANY_BATCH_SIZE): + batch = seq_of_params[i:i + _EXECUTEMANY_BATCH_SIZE] + value_rows = [] + for params in batch: + formatted = ", ".join(self._format_prepared_param(p) for p in params) + value_rows.append("(%s)" % formatted) + sql = "%s %s" % (prefix, ", ".join(value_rows)) + self._query = trino.client.TrinoQuery( + self._request, query=sql, + legacy_primitive_types=self._legacy_primitive_types) + self._iterator = iter(self._query.execute()) + if self._query.update_type is None: + raise NotSupportedError("Query must return update type") + if i + _EXECUTEMANY_BATCH_SIZE < len(seq_of_params): + self.fetchall() return self def fetchone(self) -> Optional[List[Any]]: