@@ -34,8 +34,7 @@ def _common_kernels_alloc():
3434 }
3535 }
3636 """
37- arch = "" .join (f"{ i } " for i in Device ().compute_capability )
38- program_options = ProgramOptions (std = "c++17" , arch = f"sm_{ arch } " )
37+ program_options = ProgramOptions (std = "c++17" , arch = f"sm_{ Device ().arch } " )
3938 prog = Program (code , code_type = "c++" , options = program_options )
4039 mod = prog .compile ("cubin" , name_expressions = ("set_zero" , "add_one" ))
4140 return mod
@@ -76,10 +75,10 @@ def free(self, buffers):
7675
7776
7877@pytest .mark .parametrize ("mode" , ["no_graph" , "global" , "thread_local" , "relaxed" ])
79- def test_graph_alloc (init_cuda , mode ):
78+ def test_graph_alloc (mempool_device , mode ):
8079 """Test basic graph capture with memory allocated and deallocated by GraphMemoryResource."""
8180 NBYTES = 64
82- device = Device ()
81+ device = mempool_device
8382 stream = device .create_stream ()
8483 dmr = DeviceMemoryResource (device )
8584 gmr = GraphMemoryResource (device )
@@ -118,10 +117,10 @@ def apply_kernels(mr, stream, out):
118117
119118@pytest .mark .skipif (IS_WINDOWS or IS_WSL , reason = "auto_free_on_launch not supported on Windows" )
120119@pytest .mark .parametrize ("mode" , ["global" , "thread_local" , "relaxed" ])
121- def test_graph_alloc_with_output (init_cuda , mode ):
120+ def test_graph_alloc_with_output (mempool_device , mode ):
122121 """Test for memory allocated in a graph being used outside the graph."""
123122 NBYTES = 64
124- device = Device ()
123+ device = mempool_device
125124 stream = device .create_stream ()
126125 gmr = GraphMemoryResource (device )
127126
@@ -157,8 +156,8 @@ def test_graph_alloc_with_output(init_cuda, mode):
157156
158157
159158@pytest .mark .parametrize ("mode" , ["global" , "thread_local" , "relaxed" ])
160- def test_graph_mem_set_attributes (init_cuda , mode ):
161- device = Device ()
159+ def test_graph_mem_set_attributes (mempool_device , mode ):
160+ device = mempool_device
162161 stream = device .create_stream ()
163162 gmr = GraphMemoryResource (device )
164163 mman = GraphMemoryTestManager (gmr , stream , mode )
@@ -209,12 +208,12 @@ def test_graph_mem_set_attributes(init_cuda, mode):
209208
210209
211210@pytest .mark .parametrize ("mode" , ["global" , "thread_local" , "relaxed" ])
212- def test_gmr_check_capture_state (init_cuda , mode ):
211+ def test_gmr_check_capture_state (mempool_device , mode ):
213212 """
214213 Test expected errors (and non-errors) using GraphMemoryResource with graph
215214 capture.
216215 """
217- device = Device ()
216+ device = mempool_device
218217 stream = device .create_stream ()
219218 gmr = GraphMemoryResource (device )
220219
@@ -233,12 +232,12 @@ def test_gmr_check_capture_state(init_cuda, mode):
233232
234233
235234@pytest .mark .parametrize ("mode" , ["global" , "thread_local" , "relaxed" ])
236- def test_dmr_check_capture_state (init_cuda , mode ):
235+ def test_dmr_check_capture_state (mempool_device , mode ):
237236 """
238237 Test expected errors (and non-errors) using DeviceMemoryResource with graph
239238 capture.
240239 """
241- device = Device ()
240+ device = mempool_device
242241 stream = device .create_stream ()
243242 dmr = DeviceMemoryResource (device )
244243
0 commit comments