Skip to content

Commit 556c6bf

Browse files
committed
Adjust deallocation stream for legacy memory resources to avoid platform-dependent errors. Add dependence on mempool_device where needed for certain tests. Touch-ups.
1 parent af22c81 commit 556c6bf

7 files changed

Lines changed: 36 additions & 39 deletions

File tree

cuda_core/cuda/core/experimental/_memory/_graph_memory_resource.pyx

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,16 +148,13 @@ cdef class cyGraphMemoryResource(MemoryResource):
148148

149149
class GraphMemoryResource(cyGraphMemoryResource):
150150
"""
151-
A memory resource managing the graph-specific memory pool.
151+
A memory resource for memory related to graphs.
152152
153-
Graph-captured memory operations use a special internal memory pool, which
154-
is a per-device singleton. This class serves as the interface to that pool.
155153
The only supported operations are allocation, deallocation, and a limited
156154
set of status queries.
157155
158-
This memory resource should be used to allocate memory when graph capturing
159-
is enabled. Using this when graphs are not being captured will result in a
160-
runtime error.
156+
This memory resource should be used when building graphs. Using this when
157+
graphs capture is not enabled will result in a runtime error.
161158
162159
Conversely, allocating memory from a `DeviceMemoryResource` when graph
163160
capturing is enabled results in a runtime error.

cuda_core/cuda/core/experimental/_memory/_legacy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def deallocate(self, ptr: DevicePointerT, size, stream):
6262
stream : Stream
6363
The stream on which to perform the deallocation synchronously.
6464
"""
65-
stream.sync()
65+
if stream is not None:
66+
stream.sync()
6667
(err,) = driver.cuMemFreeHost(ptr)
6768
raise_if_driver_error(err)
6869

@@ -97,10 +98,11 @@ def allocate(self, size, stream=None) -> Buffer:
9798
stream = default_stream()
9899
err, ptr = driver.cuMemAlloc(size)
99100
raise_if_driver_error(err)
100-
return Buffer._init(ptr, size, self)
101+
return Buffer._init(ptr, size, self, stream)
101102

102103
def deallocate(self, ptr, size, stream):
103-
stream.sync()
104+
if stream is not None:
105+
stream.sync()
104106
(err,) = driver.cuMemFree(ptr)
105107
raise_if_driver_error(err)
106108

cuda_core/cuda/core/experimental/_memory/_virtual_memory_resource.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
)
2121

2222
if TYPE_CHECKING:
23-
from cuda.core.experimental._graph import GraphBuilder
2423
from cuda.core.experimental._stream import Stream
2524

2625
__all__ = ["VirtualMemoryResourceOptions", "VirtualMemoryResource"]
@@ -465,7 +464,7 @@ def _build_access_descriptors(self, prop: driver.CUmemAllocationProp) -> list:
465464

466465
return descs
467466

468-
def allocate(self, size: int, stream: Stream | GraphBuilder | None = None) -> Buffer:
467+
def allocate(self, size: int, stream: Stream | None = None) -> Buffer:
469468
"""
470469
Allocate a buffer of the given size using CUDA virtual memory.
471470
@@ -549,7 +548,7 @@ def allocate(self, size: int, stream: Stream | GraphBuilder | None = None) -> Bu
549548
buf = Buffer.from_handle(ptr=ptr, size=aligned_size, mr=self)
550549
return buf
551550

552-
def deallocate(self, ptr: int, size: int, stream: Stream | GraphBuilder | None = None) -> None:
551+
def deallocate(self, ptr: int, size: int, stream: Stream | None = None) -> None:
553552
"""
554553
Deallocate memory on the device using CUDA VMM APIs.
555554
"""

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,9 @@ cdef cydriver.CUstream _handle_from_stream_protocol(obj) except*:
453453
# needed checks and returns the relevant stream.
454454
cdef Stream Stream_accept(arg, bint allow_stream_protocol=False):
455455
if isinstance(arg, Stream):
456-
return <Stream> arg
456+
return <Stream>(arg)
457457
elif isinstance(arg, GraphBuilder):
458-
return <Stream> arg.stream
458+
return <Stream>(arg.stream)
459459
elif allow_stream_protocol:
460460
try:
461461
stream = Stream._init(arg)
@@ -469,5 +469,5 @@ cdef Stream Stream_accept(arg, bint allow_stream_protocol=False):
469469
stacklevel=2,
470470
category=DeprecationWarning,
471471
)
472-
return <Stream> stream
472+
return <Stream>(stream)
473473
raise TypeError(f"Stream or GraphBuilder expected, got {type(arg).__name__}")

cuda_core/tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,16 @@ def ipc_memory_resource(ipc_device):
110110
return mr
111111

112112

113+
@pytest.fixture
114+
def mempool_device():
115+
"""Obtains a device suitable for mempool tests, or skips."""
116+
device = Device()
117+
device.set_current()
118+
119+
if not device.properties.memory_pools_supported:
120+
pytest.skip("Device does not support mempool operations")
121+
122+
return device
123+
124+
113125
skipif_need_cuda_headers = pytest.mark.skipif(helpers.CUDA_INCLUDE_PATH is None, reason="need CUDA header")

cuda_core/tests/test_graph_mem.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def _common_kernels_alloc():
3434
}
3535
}
3636
"""
37-
arch = "".join(f"{i}" for i in Device().compute_capability)
38-
program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
37+
program_options = ProgramOptions(std="c++17", arch=f"sm_{Device().arch}")
3938
prog = Program(code, code_type="c++", options=program_options)
4039
mod = prog.compile("cubin", name_expressions=("set_zero", "add_one"))
4140
return mod
@@ -76,10 +75,10 @@ def free(self, buffers):
7675

7776

7877
@pytest.mark.parametrize("mode", ["no_graph", "global", "thread_local", "relaxed"])
79-
def test_graph_alloc(init_cuda, mode):
78+
def test_graph_alloc(mempool_device, mode):
8079
"""Test basic graph capture with memory allocated and deallocated by GraphMemoryResource."""
8180
NBYTES = 64
82-
device = Device()
81+
device = mempool_device
8382
stream = device.create_stream()
8483
dmr = DeviceMemoryResource(device)
8584
gmr = GraphMemoryResource(device)
@@ -118,10 +117,10 @@ def apply_kernels(mr, stream, out):
118117

119118
@pytest.mark.skipif(IS_WINDOWS or IS_WSL, reason="auto_free_on_launch not supported on Windows")
120119
@pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"])
121-
def test_graph_alloc_with_output(init_cuda, mode):
120+
def test_graph_alloc_with_output(mempool_device, mode):
122121
"""Test for memory allocated in a graph being used outside the graph."""
123122
NBYTES = 64
124-
device = Device()
123+
device = mempool_device
125124
stream = device.create_stream()
126125
gmr = GraphMemoryResource(device)
127126

@@ -157,8 +156,8 @@ def test_graph_alloc_with_output(init_cuda, mode):
157156

158157

159158
@pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"])
160-
def test_graph_mem_set_attributes(init_cuda, mode):
161-
device = Device()
159+
def test_graph_mem_set_attributes(mempool_device, mode):
160+
device = mempool_device
162161
stream = device.create_stream()
163162
gmr = GraphMemoryResource(device)
164163
mman = GraphMemoryTestManager(gmr, stream, mode)
@@ -209,12 +208,12 @@ def test_graph_mem_set_attributes(init_cuda, mode):
209208

210209

211210
@pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"])
212-
def test_gmr_check_capture_state(init_cuda, mode):
211+
def test_gmr_check_capture_state(mempool_device, mode):
213212
"""
214213
Test expected errors (and non-errors) using GraphMemoryResource with graph
215214
capture.
216215
"""
217-
device = Device()
216+
device = mempool_device
218217
stream = device.create_stream()
219218
gmr = GraphMemoryResource(device)
220219

@@ -233,12 +232,12 @@ def test_gmr_check_capture_state(init_cuda, mode):
233232

234233

235234
@pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"])
236-
def test_dmr_check_capture_state(init_cuda, mode):
235+
def test_dmr_check_capture_state(mempool_device, mode):
237236
"""
238237
Test expected errors (and non-errors) using DeviceMemoryResource with graph
239238
capture.
240239
"""
241-
device = Device()
240+
device = mempool_device
242241
stream = device.create_stream()
243242
dmr = DeviceMemoryResource(device)
244243

cuda_core/tests/test_memory.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,6 @@
3939
POOL_SIZE = 2097152 # 2MB size
4040

4141

42-
@pytest.fixture(scope="function")
43-
def mempool_device():
44-
"""Obtains a device suitable for mempool tests, or skips."""
45-
device = Device()
46-
device.set_current()
47-
48-
if not device.properties.memory_pools_supported:
49-
pytest.skip("Device does not support mempool operations")
50-
51-
return device
52-
53-
5442
class DummyDeviceMemoryResource(MemoryResource):
5543
def __init__(self, device):
5644
self.device = device

0 commit comments

Comments
 (0)