Skip to content

Commit 7042437

Browse files
committed
Simplify Stream_accept. Default arguments can more easily be handled outside that function.
1 parent e5ea645 commit 7042437

6 files changed

Lines changed: 11 additions & 18 deletions

File tree

cuda_core/cuda/core/experimental/_launcher.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kern
5454
launching kernel.
5555
5656
"""
57-
stream = Stream_accept(stream, allow_default=False, default_value=None, allow_stream_protocol=True)
57+
stream = Stream_accept(stream, allow_stream_protocol=True)
5858
assert_type(kernel, Kernel)
5959
_lazy_init()
6060
config = check_or_create_options(LaunchConfig, config, "launch config")

cuda_core/cuda/core/experimental/_memory/_buffer.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ cdef class Buffer:
273273
cdef inline void Buffer_close(Buffer self, stream):
274274
cdef Stream s
275275
if self._ptr and self._memory_resource is not None:
276-
s = Stream_accept(stream, allow_default=True, default_value=self._alloc_stream)
276+
s = Stream_accept(stream) if stream is not None else self._alloc_stream
277277
self._memory_resource.deallocate(self._ptr, self._size, s)
278278
self._ptr = 0
279279
self._memory_resource = None

cuda_core/cuda/core/experimental/_memory/_device_memory_resource.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ from cuda.bindings cimport cydriver
1212
from cuda.core.experimental._memory._buffer cimport Buffer, MemoryResource
1313
from cuda.core.experimental._memory cimport _ipc
1414
from cuda.core.experimental._memory._ipc cimport IPCAllocationHandle, IPCData
15-
from cuda.core.experimental._stream cimport Stream_accept, Stream
15+
from cuda.core.experimental._stream cimport default_stream, Stream_accept, Stream
1616
from cuda.core.experimental._utils.cuda_utils cimport (
1717
check_or_create_options,
1818
HANDLE_RETURN,
@@ -334,7 +334,7 @@ cdef class DeviceMemoryResource(MemoryResource):
334334
"""
335335
if self.is_mapped:
336336
raise TypeError("Cannot allocate from a mapped IPC-enabled memory resource")
337-
stream = Stream_accept(stream, allow_default=True)
337+
stream = Stream_accept(stream) if stream is not None else default_stream()
338338
return DMR_allocate(self, size, <Stream> stream)
339339

340340
def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream | GraphBuilder | None = None):
@@ -351,7 +351,7 @@ cdef class DeviceMemoryResource(MemoryResource):
351351
If the buffer is deallocated without an explicit stream, the allocation stream
352352
is used.
353353
"""
354-
stream = Stream_accept(stream, allow_default=True)
354+
stream = Stream_accept(stream) if stream is not None else default_stream()
355355
DMR_deallocate(self, <uintptr_t>ptr, size, <Stream> stream)
356356

357357
@property

cuda_core/cuda/core/experimental/_memory/_graph_memory_resource.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ from libc.stdint cimport intptr_t
88

99
from cuda.bindings cimport cydriver
1010
from cuda.core.experimental._memory._buffer cimport Buffer, MemoryResource
11-
from cuda.core.experimental._stream cimport Stream_accept, Stream
11+
from cuda.core.experimental._stream cimport default_stream, Stream_accept, Stream
1212
from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN
1313

1414
from functools import cache
@@ -106,14 +106,14 @@ cdef class cyGraphMemoryResource(MemoryResource):
106106
"""
107107
Allocate a buffer of the requested size. See documentation for :obj:`~_memory.MemoryResource`.
108108
"""
109-
stream = Stream_accept(stream, allow_default=True)
109+
stream = Stream_accept(stream) if stream is not None else default_stream()
110110
return GMR_allocate(self, size, <Stream> stream)
111111

112112
def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream | GraphBuilder | None = None):
113113
"""
114114
Deallocate a buffer of the requested size. See documentation for :obj:`~_memory.MemoryResource`.
115115
"""
116-
stream = Stream_accept(stream, allow_default=True)
116+
stream = Stream_accept(stream) if stream is not None else default_stream()
117117
return GMR_deallocate(ptr, size, <Stream> stream)
118118

119119
def close(self):

cuda_core/cuda/core/experimental/_stream.pxd

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,4 @@ cdef class Stream:
2222

2323

2424
cpdef Stream default_stream()
25-
cdef Stream Stream_accept(arg, bint allow_default=*, Stream default_value=*, bint allow_stream_protocol=*)
26-
# from cuda.core.experimental._stream cimport Stream_accept
25+
cdef Stream Stream_accept(arg, bint allow_stream_protocol=*)

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -451,14 +451,8 @@ cdef cydriver.CUstream _handle_from_stream_protocol(obj) except*:
451451

452452
# Helper for API functions that accept either Stream or GraphBuilder. Performs
453453
# needed checks and returns the relevant stream.
454-
cdef Stream Stream_accept(arg, bint allow_default=False, Stream default_value=None, bint allow_stream_protocol=False):
455-
if arg is None:
456-
if allow_default:
457-
if default_value is not None:
458-
return default_value
459-
else:
460-
return default_stream()
461-
elif isinstance(arg, Stream):
454+
cdef Stream Stream_accept(arg, bint allow_stream_protocol=False):
455+
if isinstance(arg, Stream):
462456
return <Stream> arg
463457
elif isinstance(arg, GraphBuilder):
464458
return <Stream> arg.stream

0 commit comments

Comments
 (0)