|
1 | 1 | # Test the support for SSL and sockets |
2 | 2 |
|
| 3 | +import contextlib |
3 | 4 | import sys |
4 | 5 | import unittest |
5 | 6 | import unittest.mock |
|
47 | 48 |
|
48 | 49 | PROTOCOLS = sorted(ssl._PROTOCOL_NAMES) |
49 | 50 | HOST = socket_helper.HOST |
| 51 | +IS_AWS_LC = "AWS-LC" in ssl.OPENSSL_VERSION |
50 | 52 | IS_OPENSSL_3_0_0 = ssl.OPENSSL_VERSION_INFO >= (3, 0, 0) |
51 | 53 | CAN_GET_SELECTED_OPENSSL_GROUP = ssl.OPENSSL_VERSION_INFO >= (3, 2) |
52 | 54 | CAN_IGNORE_UNKNOWN_OPENSSL_GROUPS = ssl.OPENSSL_VERSION_INFO >= (3, 3) |
53 | 55 | CAN_GET_AVAILABLE_OPENSSL_GROUPS = ssl.OPENSSL_VERSION_INFO >= (3, 5) |
54 | 56 | CAN_GET_AVAILABLE_OPENSSL_SIGALGS = ssl.OPENSSL_VERSION_INFO >= (3, 4) |
55 | | -CAN_SET_CLIENT_SIGALGS = "AWS-LC" not in ssl.OPENSSL_VERSION |
| 57 | +CAN_SET_CLIENT_SIGALGS = not IS_AWS_LC |
56 | 58 | CAN_IGNORE_UNKNOWN_OPENSSL_SIGALGS = ssl.OPENSSL_VERSION_INFO >= (3, 3) |
57 | 59 | CAN_GET_SELECTED_OPENSSL_SIGALG = ssl.OPENSSL_VERSION_INFO >= (3, 5) |
58 | 60 | PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS') |
@@ -383,6 +385,20 @@ def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True, |
383 | 385 | return client_context, server_context, hostname |
384 | 386 |
|
385 | 387 |
|
| 388 | +def do_ssl_object_handshake(sslobject, outgoing, max_retry=25): |
| 389 | + """Call do_handshake() on the sslobject and return the sent data. |
| 390 | +
|
| 391 | + If do_handshake() fails more than *max_retry* times, return None. |
| 392 | + """ |
| 393 | + data, attempt = None, 0 |
| 394 | + while not data and attempt < max_retry: |
| 395 | + with contextlib.suppress(ssl.SSLWantReadError): |
| 396 | + sslobject.do_handshake() |
| 397 | + data = outgoing.read() |
| 398 | + attempt += 1 |
| 399 | + return data |
| 400 | + |
| 401 | + |
386 | 402 | class BasicSocketTests(unittest.TestCase): |
387 | 403 |
|
388 | 404 | def test_constants(self): |
@@ -1535,6 +1551,49 @@ def dummycallback(sock, servername, ctx): |
1535 | 1551 | ctx.set_servername_callback(None) |
1536 | 1552 | ctx.set_servername_callback(dummycallback) |
1537 | 1553 |
|
| 1554 | + def test_sni_callback_on_dead_references(self): |
| 1555 | + # See https://github.com/python/cpython/issues/146080. |
| 1556 | + c_ctx = make_test_context() |
| 1557 | + c_inc, c_out = ssl.MemoryBIO(), ssl.MemoryBIO() |
| 1558 | + client = c_ctx.wrap_bio(c_inc, c_out, server_hostname=SIGNED_CERTFILE_HOSTNAME) |
| 1559 | + |
| 1560 | + def sni_callback(sock, servername, ctx): pass |
| 1561 | + sni_callback = unittest.mock.Mock(wraps=sni_callback) |
| 1562 | + s_ctx = make_test_context(server_side=True, certfile=SIGNED_CERTFILE) |
| 1563 | + s_ctx.set_servername_callback(sni_callback) |
| 1564 | + |
| 1565 | + s_inc, s_out = ssl.MemoryBIO(), ssl.MemoryBIO() |
| 1566 | + server = s_ctx.wrap_bio(s_inc, s_out, server_side=True) |
| 1567 | + server_impl = server._sslobj |
| 1568 | + |
| 1569 | + # Perform the handshake on the client side first. |
| 1570 | + data = do_ssl_object_handshake(client, c_out) |
| 1571 | + sni_callback.assert_not_called() |
| 1572 | + if data is None: |
| 1573 | + self.skipTest("cannot establish a handshake from the client") |
| 1574 | + s_inc.write(data) |
| 1575 | + sni_callback.assert_not_called() |
| 1576 | + # Delete the server object before it starts doing its handshake |
| 1577 | + # and ensure that we did not call the SNI callback yet. |
| 1578 | + del server |
| 1579 | + gc.collect() |
| 1580 | + # Try to continue the server's handshake by directly using |
| 1581 | + # the internal SSL object. The latter is a weak reference |
| 1582 | + # stored in the server context and has now a dead owner. |
| 1583 | + with self.assertRaises(ssl.SSLError) as cm: |
| 1584 | + server_impl.do_handshake() |
| 1585 | + # The SNI C callback raised an exception before calling our callback. |
| 1586 | + sni_callback.assert_not_called() |
| 1587 | + |
| 1588 | + # In AWS-LC, any handshake failures reports SSL_R_PARSE_TLSEXT, |
| 1589 | + # while OpenSSL uses SSL_R_CALLBACK_FAILED on SNI callback failures. |
| 1590 | + if IS_AWS_LC: |
| 1591 | + libssl_error_reason = "PARSE_TLSEXT" |
| 1592 | + else: |
| 1593 | + libssl_error_reason = "callback failed" |
| 1594 | + self.assertIn(libssl_error_reason, str(cm.exception)) |
| 1595 | + self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_SSL) |
| 1596 | + |
1538 | 1597 | def test_sni_callback_refcycle(self): |
1539 | 1598 | # Reference cycles through the servername callback are detected |
1540 | 1599 | # and cleared. |
|
0 commit comments