Skip to content

Commit 6654645

Browse files
cuda.core: keep kernel-argument objects alive in graph kernel nodes
`GraphDefinition.launch()` did not extend the lifetime of the Python kernel-argument objects to the lifetime of the graph. The `ParamHolder` built in `GN_launch` held the only references to those objects and was destroyed when `GN_launch` returned. The driver only stores the raw pointer values in the kernel node, so a `Buffer` reachable only through the call could be GC'd before the graph ran, leaving the graph with a stale device pointer. Attach the `kernel_args` tuple to the graph as a CUDA user object, mirroring the existing handling of `KernelHandle` and `EventHandle`. This reuses the `_py_host_destructor` path already used by the host callback machinery. Closes #2039 Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 50c19d0 commit 6654645

3 files changed

Lines changed: 89 additions & 1 deletion

File tree

cuda_core/cuda/core/graph/_graph_node.pyx

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from __future__ import annotations
88

9+
from cpython.ref cimport Py_INCREF
10+
911
from libc.stddef cimport size_t
1012
from libc.stdint cimport uintptr_t
1113
from libc.string cimport memset as c_memset
@@ -54,6 +56,7 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value
5456
from cuda.core.graph._utils cimport (
5557
_attach_host_callback_to_graph,
5658
_attach_user_object,
59+
_py_host_destructor,
5760
)
5861

5962
import weakref
@@ -617,6 +620,12 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker,
617620
_attach_user_object(as_cu(h_graph), <void*>new KernelHandle(ker._h_kernel),
618621
<cydriver.CUhostFn>_destroy_kernel_handle_copy)
619622

623+
cdef object kernel_args = ker_args.kernel_args
624+
if kernel_args is not None:
625+
Py_INCREF(kernel_args)
626+
_attach_user_object(as_cu(h_graph), <void*>kernel_args,
627+
<cydriver.CUhostFn>_py_host_destructor)
628+
620629
return _registered(KernelNode._create_with_params(
621630
create_graph_node_handle(new_node, h_graph),
622631
conf.grid, conf.block, conf.shmem_size,

cuda_core/cuda/core/graph/_utils.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ from cuda.bindings cimport cydriver
77

88
cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil
99

10+
cdef void _py_host_destructor(void* data) noexcept with gil
11+
1012
cdef void _attach_user_object(
1113
cydriver.CUgraph graph, void* ptr,
1214
cydriver.CUhostFn destroy) except *

cuda_core/tests/graph/test_graph_definition_lifetime.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33

44
"""Tests for GraphDefinition resource lifetime management and RAII correctness."""
55

6+
import ctypes
67
import gc
8+
import weakref
79

810
import pytest
911
from helpers.graph_kernels import compile_common_kernels
1012
from helpers.misc import try_create_condition
1113

12-
from cuda.core import Device, EventOptions, Kernel, LaunchConfig
14+
from cuda.core import Device, DeviceMemoryResource, EventOptions, Kernel, LaunchConfig
1315
from cuda.core.graph import (
1416
ChildGraphNode,
1517
ConditionalNode,
@@ -485,3 +487,78 @@ def test_kernel_node_reconstruction_preserves_validity(init_cuda):
485487
stream = Device().create_stream()
486488
graph.launch(stream)
487489
stream.sync()
490+
491+
492+
# =============================================================================
493+
# Kernel argument lifetime — kernel nodes should keep argument objects alive
494+
# =============================================================================
495+
496+
497+
def test_kernel_args_buffer_kept_alive_through_execution(init_cuda):
498+
"""Buffer passed as a kernel arg is kept alive by the graph, and the kernel
499+
actually executes against its memory after the original Python ref drops.
500+
501+
Without the user-object attachment, the ParamHolder is destroyed when the
502+
kernel node is added, the Buffer is GC'd, and the graph is left with a
503+
stale device pointer.
504+
"""
505+
from cuda.core._utils.cuda_utils import driver, handle_return
506+
507+
_skip_if_no_mempool()
508+
dev = Device()
509+
mr = DeviceMemoryResource(dev)
510+
add_one = compile_common_kernels().get_kernel("add_one")
511+
buf = mr.allocate(ctypes.sizeof(ctypes.c_int), stream=dev.default_stream)
512+
buf.fill(0, stream=dev.default_stream)
513+
dev.default_stream.sync()
514+
buf_weak = weakref.ref(buf)
515+
dptr = int(buf.handle)
516+
517+
g = GraphDefinition()
518+
g.launch(LaunchConfig(grid=1, block=1), add_one, buf)
519+
520+
del buf
521+
gc.collect()
522+
assert buf_weak() is not None # graph kept the Buffer alive
523+
524+
stream = dev.create_stream()
525+
g.instantiate().launch(stream)
526+
stream.sync()
527+
528+
out = (ctypes.c_int * 1)(0)
529+
handle_return(driver.cuMemcpyDtoH(out, dptr, ctypes.sizeof(ctypes.c_int)))
530+
assert out[0] == 1
531+
532+
533+
def test_kernel_args_survive_graph_clone(init_cuda):
534+
"""Cloned graph keeps Buffer alive via CUDA user objects.
535+
536+
A graph clone does not inherit Python-level references, so only user
537+
objects (which propagate through cuGraphClone) can keep the args alive.
538+
"""
539+
from cuda.core._utils.cuda_utils import driver, handle_return
540+
541+
_skip_if_no_mempool()
542+
dev = Device()
543+
mr = DeviceMemoryResource(dev)
544+
add_one = compile_common_kernels().get_kernel("add_one")
545+
buf = mr.allocate(ctypes.sizeof(ctypes.c_int), stream=dev.default_stream)
546+
buf.fill(0, stream=dev.default_stream)
547+
dev.default_stream.sync()
548+
dptr = int(buf.handle)
549+
550+
g = GraphDefinition()
551+
g.launch(LaunchConfig(grid=1, block=1), add_one, buf)
552+
cloned_cu_graph = handle_return(driver.cuGraphClone(driver.CUgraph(g.handle)))
553+
554+
del buf, g
555+
gc.collect()
556+
557+
graph_exec = handle_return(driver.cuGraphInstantiate(cloned_cu_graph, 0))
558+
stream = dev.create_stream()
559+
handle_return(driver.cuGraphLaunch(graph_exec, driver.CUstream(int(stream.handle))))
560+
stream.sync()
561+
562+
out = (ctypes.c_int * 1)(0)
563+
handle_return(driver.cuMemcpyDtoH(out, dptr, ctypes.sizeof(ctypes.c_int)))
564+
assert out[0] == 1

0 commit comments

Comments
 (0)