Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions Lib/test/test__interpchannels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import time
import unittest

from test import support
from test.support import import_helper
from test.support.script_helper import run_python_until_end

_channels = import_helper.import_module('_interpchannels')
from concurrent.interpreters import _crossinterp
Expand Down Expand Up @@ -308,6 +310,33 @@ def test_bad_kwargs(self):
with self.assertRaises(ValueError):
_channels._channel_id(10, send=False, recv=False)

@unittest.skipIf(support.Py_TRACE_REFS,
'_testcapi.set_nomemory() is unreliable with Py_TRACE_REFS')
def test_oom_in_module_lookup(self):
import_helper.import_module('_testcapi')
code = dedent("""
import _interpchannels
import _testcapi
from test.support import SuppressCrashReport

seen = False
with SuppressCrashReport():
_testcapi.set_nomemory(0, 1)
try:
_interpchannels._channel_id(0)
except MemoryError:
seen = True
finally:
_testcapi.remove_mem_hooks()
if not seen:
raise AssertionError("MemoryError not raised")
print("MemoryError")
""")
with support.SuppressCrashReport():
res, _ = run_python_until_end("-c", code)
self.assertEqual(res.rc, 0, res.err.decode("ascii", "replace"))
self.assertIn(b"MemoryError", res.out)

def test_does_not_exist(self):
cid = _channels.create(REPLACE)
with self.assertRaises(_channels.ChannelNotFoundError):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix a crash in ``_interpchannels._channel_id()`` when memory allocation
fails while looking up module state. The function now correctly propagates
``MemoryError``.
3 changes: 3 additions & 0 deletions Modules/_interpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -3484,6 +3484,9 @@ channelsmod__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
PyTypeObject *cls = state->ChannelIDType;

PyObject *mod = get_module_from_owned_type(cls);
if (mod == NULL) {
return NULL;
}
assert(mod == self);
Py_DECREF(mod);

Expand Down
Loading