diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 244afa9c..06edf7bb 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -412,6 +412,9 @@ def _map_sql_type( # pylint: disable=too-many-arguments,too-many-positional-arg logger.debug("_map_sql_type: Mapping param index=%d, type=%s", i, type(param).__name__) if param is None: logger.debug("_map_sql_type: NULL parameter - index=%d", i) + # GH-610: Send SQL_UNKNOWN_TYPE to C++ where the describe-cache + # in BindParameters / BindParameterArray resolves the correct + # type via SQLDescribeParam (cached after first call). return ( ddbc_sql_const.SQL_UNKNOWN_TYPE.value, ddbc_sql_const.SQL_C_DEFAULT.value, @@ -2348,14 +2351,10 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s max_val=max_val, ) - # For executemany with all-NULL columns, SQL_UNKNOWN_TYPE doesn't work - # with array binding. Fall back to SQL_VARCHAR as a safe default. - if ( - sample_value is None - and paraminfo.paramSQLType == ddbc_sql_const.SQL_UNKNOWN_TYPE.value - ): - paraminfo.paramSQLType = ddbc_sql_const.SQL_VARCHAR.value - paraminfo.columnSize = 1 + # GH-610: all-NULL columns now pass SQL_UNKNOWN_TYPE to C++, + # where BindParameterArray resolves the correct type via the + # SQLDescribeParam cache. The previous SQL_VARCHAR hardcoded + # fallback was removed because it broke VARBINARY columns. # Override DECIMAL/NUMERIC to use SQL_C_CHAR string binding. # _map_sql_type may return SQL_C_NUMERIC (expecting NumericData structs) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 5ed7820f..7f1861cd 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -5,20 +5,22 @@ // agnostic will be // taken up in beta release #include "ddbc_bindings.h" -#include "utf_utils.h" #include "connection/connection.h" #include "connection/connection_pool.h" #include "logger_bridge.hpp" +#include "utf_utils.h" + +#include // std::min #include #include #include // For std::memcpy -#include // std::min #include #include // std::setw, std::setfill #include #include // std::forward + //------------------------------------------------------------------------------------------------- // Macro definitions //------------------------------------------------------------------------------------------------- @@ -470,9 +472,48 @@ std::string DescribeChar(unsigned char ch) { } } +// GH-610: Resolve SQL type for a NULL parameter using per-handle cache. +// On cache miss, calls SQLDescribeParam and stores the result. +static DescribedParamInfo ResolveNullParamType( + SqlHandle& handle, SQLHANDLE hStmt, int paramIndex) { + // Check per-handle cache (no mutex — one handle per thread) + auto it = handle.describeCache.find(paramIndex); + if (it != handle.describeCache.end()) { + LOG("ResolveNullParamType: Cache HIT for hStmt=%p param[%d] " + "-> sqlType=%d", (void*)hStmt, paramIndex, it->second.sqlType); + return it->second; + } + + // Cache miss — call SQLDescribeParam + SQLSMALLINT type, digits, nullable; + SQLULEN size; + LOG("ResolveNullParamType: Cache MISS for hStmt=%p param[%d], calling " + "SQLDescribeParam", (void*)hStmt, paramIndex); + RETCODE rc = SQLDescribeParam_ptr( + hStmt, static_cast(paramIndex + 1), + &type, &size, &digits, &nullable); + + DescribedParamInfo info; + if (SQL_SUCCEEDED(rc)) { + info = {type, size, digits, true}; + LOG("ResolveNullParamType: SQLDescribeParam succeeded for param[%d] " + "-> sqlType=%d, columnSize=%lu, decimalDigits=%d", + paramIndex, type, (unsigned long)size, digits); + } else { + info = {SQL_VARCHAR, 1, 0, false}; + LOG_WARNING("ResolveNullParamType: SQLDescribeParam failed for " + "param[%d] (rc=%d), falling back to SQL_VARCHAR", + paramIndex, rc); + } + + // Store in per-handle cache + handle.describeCache[paramIndex] = info; + return info; +} + // Given a list of parameters and their ParamInfo, calls SQLBindParameter on // each of them with appropriate arguments -SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, +SQLRETURN BindParameters(SqlHandle& handle, SQLHANDLE hStmt, const py::list& params, std::vector& paramInfos, std::vector>& paramBuffers, const std::string& charEncoding = "utf-8") { @@ -607,7 +648,8 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, paramBuffers, param.cast()); LOG("BindParameters: param[%d] SQL_C_WCHAR - String " "length=%zu characters, buffer=%zu bytes", - paramIndex, sqlwcharBuffer->size(), sqlwcharBuffer->size() * sizeof(SQLWCHAR)); + paramIndex, sqlwcharBuffer->size(), + sqlwcharBuffer->size() * sizeof(SQLWCHAR)); dataPtr = sqlwcharBuffer->data(); bufferLength = sqlwcharBuffer->size() * sizeof(SQLWCHAR); strLenOrIndPtr = AllocateParamBuffer(paramBuffers); @@ -627,33 +669,15 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, if (!py::isinstance(param)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } + // GH-610: Resolve SQL type for NULL params via per-handle cache. SQLSMALLINT sqlType = paramInfo.paramSQLType; SQLULEN columnSize = paramInfo.columnSize; SQLSMALLINT decimalDigits = paramInfo.decimalDigits; if (sqlType == SQL_UNKNOWN_TYPE) { - SQLSMALLINT describedType; - SQLULEN describedSize; - SQLSMALLINT describedDigits; - SQLSMALLINT nullable; - RETCODE rc = SQLDescribeParam_ptr( - hStmt, static_cast(paramIndex + 1), &describedType, - &describedSize, &describedDigits, &nullable); - if (!SQL_SUCCEEDED(rc)) { - // SQLDescribeParam can fail for generic SELECT statements where - // no table column is referenced. Fall back to SQL_VARCHAR as a safe - // default. - LOG_WARNING("BindParameters: SQLDescribeParam failed for " - "param[%d] (NULL parameter) - SQLRETURN=%d, falling back to " - "SQL_VARCHAR", - paramIndex, rc); - sqlType = SQL_VARCHAR; - columnSize = 1; - decimalDigits = 0; - } else { - sqlType = describedType; - columnSize = describedSize; - decimalDigits = describedDigits; - } + auto resolved = ResolveNullParamType(handle, hStmt, paramIndex); + sqlType = resolved.sqlType; + columnSize = resolved.columnSize; + decimalDigits = resolved.decimalDigits; } dataPtr = nullptr; strLenOrIndPtr = AllocateParamBuffer(paramBuffers); @@ -1358,6 +1382,9 @@ void SqlHandle::markImplicitlyFreed() { */ void SqlHandle::free() { if (_handle && SQLFreeHandle_ptr) { + // GH-610: Clear describe cache to prevent memory leak. + describeCache.clear(); + // Check if Python is shutting down using centralized helper function bool pythonShuttingDown = is_python_finalizing(); @@ -1450,14 +1477,12 @@ SQLRETURN SQLProcedures_wrap(SqlHandlePtr StatementHandle, const py::object& cat std::u16string catalog = catalogObj.is_none() ? u"" : catalogObj.cast(); std::u16string schema = schemaObj.is_none() ? u"" : schemaObj.cast(); - std::u16string procedure = - procedureObj.is_none() ? u"" : procedureObj.cast(); + std::u16string procedure = procedureObj.is_none() ? u"" : procedureObj.cast(); // Release the GIL during the blocking ODBC catalog call py::gil_scoped_release release; return SQLProcedures_ptr( - StatementHandle->get(), - catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), + StatementHandle->get(), catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), catalog.empty() ? 0 : SQL_NTS, schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), schema.empty() ? 0 : SQL_NTS, @@ -1509,14 +1534,13 @@ SQLRETURN SQLPrimaryKeys_wrap(SqlHandlePtr StatementHandle, const py::object& ca // Release the GIL during the blocking ODBC catalog call py::gil_scoped_release release; - return SQLPrimaryKeys_ptr( - StatementHandle->get(), - catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), - catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), - schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), - table.empty() ? 0 : SQL_NTS); + return SQLPrimaryKeys_ptr(StatementHandle->get(), + catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), + table.empty() ? 0 : SQL_NTS); } SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, const py::object& catalogObj, @@ -1531,14 +1555,13 @@ SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, const py::object& cat // Release the GIL during the blocking ODBC catalog call py::gil_scoped_release release; - return SQLStatistics_ptr( - StatementHandle->get(), - catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), - catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), - schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), - table.empty() ? 0 : SQL_NTS, unique, reserved); + return SQLStatistics_ptr(StatementHandle->get(), + catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), + table.empty() ? 0 : SQL_NTS, unique, reserved); } SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, const py::object& catalogObj, @@ -1555,16 +1578,15 @@ SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, const py::object& catalo // Release the GIL during the blocking ODBC catalog call py::gil_scoped_release release; - return SQLColumns_ptr( - StatementHandle->get(), - catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), - catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), - schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), - table.empty() ? 0 : SQL_NTS, - column.empty() ? nullptr : reinterpretU16stringAsSqlWChar(column), - column.empty() ? 0 : SQL_NTS); + return SQLColumns_ptr(StatementHandle->get(), + catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), + table.empty() ? 0 : SQL_NTS, + column.empty() ? nullptr : reinterpretU16stringAsSqlWChar(column), + column.empty() ? 0 : SQL_NTS); } // Helper function to check for driver errors @@ -1589,8 +1611,9 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET SQLINTEGER nativeError; SQLSMALLINT messageLen; - SQLRETURN diagReturn = SQLGetDiagRec_ptr(handleType, rawHandle, 1, sqlState, &nativeError, - message, SQL_MAX_MESSAGE_LENGTH_SQLSERVER, &messageLen); + SQLRETURN diagReturn = + SQLGetDiagRec_ptr(handleType, rawHandle, 1, sqlState, &nativeError, message, + SQL_MAX_MESSAGE_LENGTH_SQLSERVER, &messageLen); if (SQL_SUCCEEDED(diagReturn)) { std::u16string sqlStateUtf16 = dupeSqlWCharAsUtf16Le(sqlState, 5); @@ -1697,16 +1720,15 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, const std::u16string& cat { // Release the GIL during the blocking ODBC catalog call py::gil_scoped_release release; - ret = SQLTables_ptr( - StatementHandle->get(), - catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), - catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), - schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), - table.empty() ? 0 : SQL_NTS, - tableType.empty() ? nullptr : reinterpretU16stringAsSqlWChar(tableType), - tableType.empty() ? 0 : SQL_NTS); + ret = SQLTables_ptr(StatementHandle->get(), + catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), + table.empty() ? 0 : SQL_NTS, + tableType.empty() ? nullptr : reinterpretU16stringAsSqlWChar(tableType), + tableType.empty() ? 0 : SQL_NTS); } LOG("SQLTables: Catalog metadata query %s - SQLRETURN=%d", @@ -1719,8 +1741,7 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, const std::u16string& cat // statement and binds the parameters. Otherwise, it executes the query // directly. 'usePrepare' parameter can be used to disable the prepare step for // queries that might already be prepared in a previous call. -SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, - const std::u16string& query, +SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, const std::u16string& query, const py::list& params, std::vector& paramInfos, py::list& isStmtPrepared, const bool usePrepare, const py::dict& encodingSettings) { @@ -1787,6 +1808,8 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, rc, (void*)hStmt); return rc; } + // GH-610: Clear per-handle describe cache (new prepare = new param types) + statementHandle->clearDescribeCache(); isStmtPrepared[0] = py::cast(true); } else { // Make sure the statement has been prepared earlier if we're not @@ -1807,7 +1830,7 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, } std::vector> paramBuffers; - rc = BindParameters(hStmt, params, paramInfos, paramBuffers, charEncoding); + rc = BindParameters(*statementHandle, hStmt, params, paramInfos, paramBuffers, charEncoding); if (!SQL_SUCCEEDED(rc)) { return rc; } @@ -1955,7 +1978,7 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, } } -SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params, +SQLRETURN BindParameterArray(SqlHandle& handle, SQLHANDLE hStmt, const py::list& columnwise_params, const std::vector& paramInfos, size_t paramSetSize, std::vector>& paramBuffers, const std::string& charEncoding = "utf-8") { @@ -2543,15 +2566,21 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params, } case SQL_C_DEFAULT: { // Handle NULL parameters - all values in this column should be NULL - // The upstream Python type detection (via _compute_column_type) ensures - // SQL_C_DEFAULT is only used when all values are None LOG("BindParameterArray: Binding SQL_C_DEFAULT (NULL) array - param_index=%d, " "count=%zu", paramIndex, paramSetSize); - // For NULL parameters, we need to allocate a minimal buffer and set all - // indicators to SQL_NULL_DATA Use SQL_C_CHAR as a safe default C type for NULL - // values + // GH-610: Resolve SQL type for all-NULL columns via per-handle cache. + SQLSMALLINT resolvedSqlType = info.paramSQLType; + SQLULEN resolvedColSize = info.columnSize; + SQLSMALLINT resolvedDecDigits = info.decimalDigits; + if (resolvedSqlType == SQL_UNKNOWN_TYPE) { + auto resolved = ResolveNullParamType(handle, hStmt, paramIndex); + resolvedSqlType = resolved.sqlType; + resolvedColSize = resolved.columnSize; + resolvedDecDigits = resolved.decimalDigits; + } + char* nullBuffer = AllocateParamBufferArray(tempBuffers, paramSetSize); strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); @@ -2562,7 +2591,14 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params, dataPtr = nullBuffer; bufferLength = 1; - LOG("BindParameterArray: SQL_C_DEFAULT bound - param_index=%d", paramIndex); + + // Override info fields so SQLBindParameter below uses resolved type + const_cast(info).paramSQLType = resolvedSqlType; + const_cast(info).columnSize = resolvedColSize; + const_cast(info).decimalDigits = resolvedDecDigits; + + LOG("BindParameterArray: SQL_C_DEFAULT bound - param_index=%d, " + "resolvedSqlType=%d", paramIndex, resolvedSqlType); break; } default: { @@ -2621,6 +2657,8 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::u16 LOG("SQLExecuteMany: SQLPrepare failed - rc=%d", rc); return rc; } + // GH-610: Clear per-handle describe cache (new prepare = new param types) + statementHandle->clearDescribeCache(); LOG("SQLExecuteMany: Query prepared successfully"); bool hasDAE = false; @@ -2643,7 +2681,7 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::u16 "BindParameterArray with encoding '%s'", charEncoding.c_str()); std::vector> paramBuffers; - rc = BindParameterArray(hStmt, columnwise_params, paramInfos, paramSetSize, paramBuffers, + rc = BindParameterArray(*statementHandle, hStmt, columnwise_params, paramInfos, paramSetSize, paramBuffers, charEncoding); if (!SQL_SUCCEEDED(rc)) { LOG("SQLExecuteMany: BindParameterArray failed - rc=%d", rc); @@ -2673,7 +2711,7 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::u16 py::list rowParams = columnwise_params[rowIndex]; std::vector> paramBuffers; - rc = BindParameters(hStmt, rowParams, const_cast&>(paramInfos), + rc = BindParameters(*statementHandle, hStmt, rowParams, const_cast&>(paramInfos), paramBuffers, charEncoding); if (!SQL_SUCCEEDED(rc)) { LOG("SQLExecuteMany: BindParameters failed for row %zu - rc=%d", rowIndex, rc); @@ -2814,9 +2852,8 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta // TODO: Should we define a struct for this task instead of dict? ColumnMetadata.append( py::dict("ColumnName"_a = dupeSqlWCharAsUtf16Le( - ColumnName, - std::min(static_cast(NameLength), - (sizeof(ColumnName) / sizeof(SQLWCHAR)) - 1)), + ColumnName, std::min(static_cast(NameLength), + (sizeof(ColumnName) / sizeof(SQLWCHAR)) - 1)), "DataType"_a = DataType, "ColumnSize"_a = ColumnSize, "DecimalDigits"_a = DecimalDigits, "Nullable"_a = Nullable)); } else { @@ -2838,14 +2875,14 @@ SQLRETURN SQLSpecialColumns_wrap(SqlHandlePtr StatementHandle, SQLSMALLINT ident std::u16string schema = schemaObj.is_none() ? u"" : schemaObj.cast(); py::gil_scoped_release release; - return SQLSpecialColumns_ptr( - StatementHandle->get(), identifierType, - catalog.empty() ? nullptr : reinterpretU16stringAsSqlWChar(catalog), - catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), - schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), - table.empty() ? 0 : SQL_NTS, scope, nullable); + return SQLSpecialColumns_ptr(StatementHandle->get(), identifierType, + catalog.empty() ? nullptr + : reinterpretU16stringAsSqlWChar(catalog), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : reinterpretU16stringAsSqlWChar(schema), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : reinterpretU16stringAsSqlWChar(table), + table.empty() ? 0 : SQL_NTS, scope, nullable); } // Wrap SQLFetch to retrieve rows @@ -3172,8 +3209,8 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p // null termination. This preserves embedded NULs and avoids // any risk of reading past the valid range if the driver // omits the terminator. - row.append( - py::cast(dupeSqlWCharAsUtf16Le(dataBuffer.data(), numCharsInData))); + row.append(py::cast( + dupeSqlWCharAsUtf16Le(dataBuffer.data(), numCharsInData))); LOG("SQLGetData: CHAR column %d fetched as WCHAR, " "length=%lu", i, (unsigned long)numCharsInData); @@ -3338,8 +3375,8 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p // null termination. This preserves embedded NULs and avoids // any risk of reading past the valid range if the driver // omits the terminator. - row.append( - py::cast(dupeSqlWCharAsUtf16Le(dataBuffer.data(), numCharsInData))); + row.append(py::cast( + dupeSqlWCharAsUtf16Le(dataBuffer.data(), numCharsInData))); LOG("SQLGetData: Appended NVARCHAR string " "length=%lu for column %d", (unsigned long)numCharsInData, i); diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 0f831097..775b1b77 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -4,6 +4,7 @@ #pragma once // pybind11.h must be the first include +#include #include #include #include @@ -11,10 +12,10 @@ #include #include // Add this line for datetime support #include -#include #include #include + namespace py = pybind11; using py::literals::operator""_a; @@ -232,12 +233,23 @@ class DriverLoader { std::once_flag m_onceFlag; }; +#include + //------------------------------------------------------------------------------------------------- // SqlHandle // // RAII wrapper around ODBC handles (ENV, DBC, STMT). // Use `std::shared_ptr` (alias: SqlHandlePtr) for shared ownership. //------------------------------------------------------------------------------------------------- + +// GH-610: Cached result of SQLDescribeParam for a single parameter. +struct DescribedParamInfo { + SQLSMALLINT sqlType; + SQLULEN columnSize; + SQLSMALLINT decimalDigits; + bool succeeded; // false = SQLDescribeParam failed, used SQL_VARCHAR fallback +}; + class SqlHandle { public: SqlHandle(SQLSMALLINT type, SQLHANDLE rawHandle); @@ -262,6 +274,13 @@ class SqlHandle { // before freeing the DBC handle. void markImplicitlyFreed(); + // GH-610: Per-handle SQLDescribeParam result cache. + // Keyed by 0-based parameter index. Populated on first NULL param + // execution, cleared on SQLPrepare (new SQL = new param types) and + // on handle free. No mutex needed — each handle is used by one thread. + std::unordered_map describeCache; + void clearDescribeCache() { describeCache.clear(); } + private: SQLSMALLINT _type; SQLHANDLE _handle; diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index d39f42ae..d43c3b3d 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -2184,6 +2184,127 @@ def test_executemany_MIX_NONE_parameter_list(cursor, db_connection): db_connection.commit() +def test_map_sql_type_none_returns_sql_unknown_type(): + """Test that _map_sql_type returns SQL_UNKNOWN_TYPE for None params (GH-610). + + None returns SQL_UNKNOWN_TYPE so the C++ BindParameters cache can resolve + the correct type via SQLDescribeParam on first call and cache it for + subsequent calls. + """ + from unittest.mock import MagicMock + + from mssql_python.constants import ConstantsDDBC as ddbc_sql_const + + cursor = MagicMock(spec=mssql_python.Cursor) + _map_sql_type = mssql_python.Cursor._map_sql_type.__get__(cursor) + params = [None, 42, None] + + sql_type, c_type, col_size, dec_digits, is_dae = _map_sql_type(None, params, 0) + + assert sql_type == ddbc_sql_const.SQL_UNKNOWN_TYPE.value + assert c_type == ddbc_sql_const.SQL_C_DEFAULT.value + assert col_size == 1 + assert dec_digits == 0 + assert is_dae is False + + +# --------------------------------------------------------- +# GH-610: SQLDescribeParam cache coverage tests +# --------------------------------------------------------- + + +def test_gh610_execute_null_param_cache_miss(cursor, db_connection): + """Cover cache MISS path: first execute with NULL triggers SQLDescribeParam.""" + cursor.execute("CREATE TABLE #gh610_cov1 (id INT, name VARCHAR(50))") + cursor.execute("INSERT INTO #gh610_cov1 VALUES (?, ?)", (1, None)) + db_connection.commit() + cursor.execute("SELECT COUNT(*) FROM #gh610_cov1") + assert cursor.fetchone()[0] == 1 + cursor.execute("DROP TABLE #gh610_cov1") + + +def test_gh610_execute_null_param_cache_hit(cursor, db_connection): + """Cover cache HIT path: repeated execute with same SQL + NULL.""" + cursor.execute("CREATE TABLE #gh610_cov2 (id INT, name VARCHAR(50))") + # First call: cache miss → SQLDescribeParam + cursor.execute("INSERT INTO #gh610_cov2 VALUES (?, ?)", (1, None)) + # Second call: cache hit → no SQLDescribeParam + cursor.execute("INSERT INTO #gh610_cov2 VALUES (?, ?)", (2, None)) + # Third call: cache hit + cursor.execute("INSERT INTO #gh610_cov2 VALUES (?, ?)", (3, None)) + db_connection.commit() + cursor.execute("SELECT COUNT(*) FROM #gh610_cov2") + assert cursor.fetchone()[0] == 3 + cursor.execute("DROP TABLE #gh610_cov2") + + +def test_gh610_cache_invalidation_on_new_sql(cursor, db_connection): + """Cover InvalidateDescribeCache path: different SQL clears cache.""" + cursor.execute("CREATE TABLE #gh610_cov3a (val INT)") + cursor.execute("CREATE TABLE #gh610_cov3b (val VARCHAR(50))") + # First query — cache populated + cursor.execute("INSERT INTO #gh610_cov3a VALUES (?)", (None,)) + # Different SQL — triggers SQLPrepare → InvalidateDescribeCache + cursor.execute("INSERT INTO #gh610_cov3b VALUES (?)", (None,)) + # Back to first — triggers SQLPrepare → InvalidateDescribeCache again + cursor.execute("INSERT INTO #gh610_cov3a VALUES (?)", (None,)) + db_connection.commit() + cursor.execute("SELECT COUNT(*) FROM #gh610_cov3a") + assert cursor.fetchone()[0] == 2 + cursor.execute("DROP TABLE #gh610_cov3a") + cursor.execute("DROP TABLE #gh610_cov3b") + + +def test_gh610_executemany_all_null_column(cursor, db_connection): + """Cover BindParameterArray SQL_C_DEFAULT + SQL_UNKNOWN_TYPE path.""" + cursor.execute("CREATE TABLE #gh610_cov4 (id INT, name VARCHAR(50))") + cursor.executemany( + "INSERT INTO #gh610_cov4 VALUES (?, ?)", + [(1, None), (2, None), (3, None)], + ) + db_connection.commit() + cursor.execute("SELECT COUNT(*) FROM #gh610_cov4 WHERE name IS NULL") + assert cursor.fetchone()[0] == 3 + cursor.execute("DROP TABLE #gh610_cov4") + + +def test_gh610_executemany_multiple_all_null_columns(cursor, db_connection): + """Cover BindParameterArray with multiple all-NULL columns.""" + cursor.execute("CREATE TABLE #gh610_cov5 (id INT, a VARCHAR(50), b INT, c VARCHAR(50))") + cursor.executemany( + "INSERT INTO #gh610_cov5 VALUES (?, ?, ?, ?)", + [(1, None, None, None), (2, None, None, None)], + ) + db_connection.commit() + cursor.execute("SELECT COUNT(*) FROM #gh610_cov5") + assert cursor.fetchone()[0] == 2 + cursor.execute("DROP TABLE #gh610_cov5") + + +def test_gh610_execute_all_null_params(cursor, db_connection): + """Cover BindParameters with all params being NULL.""" + cursor.execute("CREATE TABLE #gh610_cov6 (a INT, b VARCHAR(50))") + cursor.execute("INSERT INTO #gh610_cov6 VALUES (?, ?)", (None, None)) + db_connection.commit() + cursor.execute("SELECT * FROM #gh610_cov6") + row = cursor.fetchone() + assert row[0] is None and row[1] is None + cursor.execute("DROP TABLE #gh610_cov6") + + +def test_gh610_setinputsizes_bypasses_cache(cursor, db_connection): + """setinputsizes provides type directly — cache not used.""" + from mssql_python.constants import ConstantsDDBC as C + + cursor.execute("CREATE TABLE #gh610_cov7 (val VARCHAR(50))") + cursor.setinputsizes([(C.SQL_VARCHAR.value, 50, 0)]) + cursor.execute("INSERT INTO #gh610_cov7 VALUES (?)", (None,)) + db_connection.commit() + cursor.execute("SELECT val FROM #gh610_cov7") + assert cursor.fetchone()[0] is None + cursor.execute("DROP TABLE #gh610_cov7") + + @pytest.mark.skip(reason="Skipping due to commit reliability issues with executemany") def test_executemany_concurrent_null_parameters(db_connection): """Test executemany with NULL parameters across multiple sequential operations."""