diff --git a/mssql_python/exceptions.py b/mssql_python/exceptions.py index f2285bce..ddf0fde0 100644 --- a/mssql_python/exceptions.py +++ b/mssql_python/exceptions.py @@ -30,6 +30,9 @@ def __init__(self, errors: list) -> None: message = "Connection string parsing failed:\n " + "\n ".join(errors) super().__init__(message) + def __reduce__(self): + return (self.__class__, (self.errors,)) + class Exception(builtins.Exception): """ @@ -47,6 +50,23 @@ def __init__(self, driver_error: str, ddbc_error: str) -> None: self.message = f"Driver Error: {self.driver_error}" super().__init__(self.message) + def __reduce__(self): + # Reconstruct without re-running __init__/truncate_error_message() to avoid + # emitting warnings for already-truncated "[Microsoft]..." messages. + return ( + Exception._unpickle, + (self.__class__, self.driver_error, self.ddbc_error, self.message), + ) + + @staticmethod + def _unpickle(cls, driver_error: str, ddbc_error: str, message: str): + obj = cls.__new__(cls) + obj.driver_error = driver_error + obj.ddbc_error = ddbc_error + obj.message = message + builtins.Exception.__init__(obj, message) + return obj + class Warning(Exception): """ diff --git a/tests/test_006_exceptions.py b/tests/test_006_exceptions.py index c763ed55..b0d11776 100644 --- a/tests/test_006_exceptions.py +++ b/tests/test_006_exceptions.py @@ -451,3 +451,76 @@ def test_truncate_error_message_return_paths(): # If the exception handling worked, it would have been caught # and the function would return the original message (line 531) pass + + +# --------------------------------------------------------------------------- +# Pickle / unpickle round-trip tests +# --------------------------------------------------------------------------- + + +def test_exception_pickle_roundtrip(): + """All DB-API exception subclasses must survive a pickle round-trip.""" + import pickle + import copy + + exception_classes = [ + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, + ] + + for cls in exception_classes: + original = cls("driver msg", "ddbc msg") + + # pickle round-trip + restored = pickle.loads(pickle.dumps(original)) + + assert type(restored) is cls, f"{cls.__name__}: type mismatch after unpickle" + assert restored.driver_error == "driver msg", f"{cls.__name__}: driver_error mismatch" + assert restored.ddbc_error == "ddbc msg", f"{cls.__name__}: ddbc_error mismatch" + assert str(restored) == str(original), f"{cls.__name__}: str() mismatch" + + # copy.deepcopy also uses __reduce__ + deep = copy.deepcopy(original) + assert type(deep) is cls + assert deep.driver_error == "driver msg" + + +def test_exception_pickle_empty_ddbc_error(): + """Exceptions with empty ddbc_error should also round-trip cleanly.""" + import pickle + + original = ProgrammingError("cursor is closed", "") + restored = pickle.loads(pickle.dumps(original)) + + assert type(restored) is ProgrammingError + assert restored.driver_error == "cursor is closed" + assert restored.ddbc_error == "" + assert str(restored) == str(original) + + +def test_connection_string_parse_error_pickle_roundtrip(): + """ConnectionStringParseError should survive a pickle round-trip.""" + import pickle + import copy + + errors = ["Unknown keyword: foo", "Missing value for: bar"] + original = ConnectionStringParseError(errors) + + restored = pickle.loads(pickle.dumps(original)) + + assert type(restored) is ConnectionStringParseError + assert restored.errors == errors + assert str(restored) == str(original) + + # copy.deepcopy + deep = copy.deepcopy(original) + assert type(deep) is ConnectionStringParseError + assert deep.errors == errors