@@ -82,16 +82,58 @@ def fill_kernel(init_cuda):
8282 return mod .get_kernel ("fill" )
8383
8484
85- def _safe_two_group_count ( sm ):
86- """Return a safe per-group SM count for a 2-group split.
85+ def _is_invalid_resource_configuration ( exc ):
86+ return "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str ( exc )
8787
88- Uses min_partition_size which is always a valid split size regardless
89- of hardware topology. Returns None if the device doesn't have enough SMs.
90- """
91- min_size = sm .min_partition_size
92- if sm .sm_count < 2 * min_size :
93- return None
94- return min_size
88+
89+ def _iter_requested_sm_counts (sm , n_groups = 1 , * , descending = False ):
90+ """Yield even per-group SM counts worth probing on this device."""
91+ start = max (2 , sm .min_partition_size )
92+ if start % 2 :
93+ start += 1
94+ stop = sm .sm_count // n_groups
95+ counts = range (start , stop + 1 , 2 )
96+ return reversed (counts ) if descending else counts
97+
98+
99+ def _try_sm_split (sm , * , count , backfill = False ):
100+ try :
101+ return sm .split (SMResourceOptions (count = count , backfill = backfill ))
102+ except CUDAError as exc :
103+ if _is_invalid_resource_configuration (exc ):
104+ return None
105+ raise
106+
107+
108+ def _find_supported_split (sm , * , n_groups = 1 , backfill = False , descending = False ):
109+ """Return a supported explicit split request for this device, if any."""
110+ for count in _iter_requested_sm_counts (sm , n_groups = n_groups , descending = descending ):
111+ request = count if n_groups == 1 else (count ,) * n_groups
112+ split = _try_sm_split (sm , count = request , backfill = backfill )
113+ if split is not None :
114+ groups , rem = split
115+ return count , groups , rem
116+ return None
117+
118+
119+ def _find_any_two_group_split (sm ):
120+ split = _find_supported_split (sm , n_groups = 2 )
121+ if split is not None :
122+ return split
123+ return _find_supported_split (sm , n_groups = 2 , backfill = True )
124+
125+
126+ def _find_backfill_only_two_group_split (sm ):
127+ """Return a 2-group split size that needs backfill, if the device has one."""
128+ for count in _iter_requested_sm_counts (sm , n_groups = 2 , descending = True ):
129+ request = (count , count )
130+ if _try_sm_split (sm , count = request ) is not None :
131+ continue
132+ split = _try_sm_split (sm , count = request , backfill = True )
133+ if split is not None :
134+ groups , rem = split
135+ return count , groups , rem
136+ return None
95137
96138
97139@contextlib .contextmanager
@@ -153,8 +195,10 @@ def test_arch_constraints_pre_hopper(self, init_cuda, sm_resource):
153195 def test_arch_constraints_hopper_plus (self , init_cuda , sm_resource ):
154196 if init_cuda .compute_capability < (9 , 0 ):
155197 pytest .skip ("Test is for Hopper+ architectures" )
156- assert sm_resource .min_partition_size >= 8
157- assert sm_resource .coscheduled_alignment >= 8
198+ assert sm_resource .min_partition_size >= 2
199+ assert sm_resource .coscheduled_alignment >= 2
200+ assert sm_resource .min_partition_size % 2 == 0
201+ assert sm_resource .coscheduled_alignment % 2 == 0
158202
159203
160204# ---------------------------------------------------------------------------
@@ -221,9 +265,11 @@ def test_dry_run_cannot_create_context(self, init_cuda, sm_resource):
221265
222266class TestSMResourceSplit :
223267 def test_single_group_counts (self , sm_resource ):
224- """Single-group split: group gets at least requested SMs."""
225- requested = sm_resource .min_partition_size
226- groups , rem = sm_resource .split (SMResourceOptions (count = requested ))
268+ """Single-group split: group gets at least a supported requested size."""
269+ split = _find_supported_split (sm_resource )
270+ if split is None :
271+ pytest .skip ("Device does not expose a valid explicit single-group split" )
272+ requested , groups , rem = split
227273
228274 assert len (groups ) == 1
229275 assert groups [0 ].sm_count >= requested
@@ -243,12 +289,11 @@ def test_discovery_respects_alignment(self, sm_resource):
243289 assert groups [0 ].sm_count % sm_resource .coscheduled_alignment == 0
244290
245291 def test_two_groups (self , sm_resource ):
246- """Two-group split with min_partition_size (always topology-safe)."""
247- count = _safe_two_group_count (sm_resource )
248- if count is None :
249- pytest .skip ("Not enough SMs for a 2-group split" )
250-
251- groups , rem = sm_resource .split (SMResourceOptions (count = (count , count )))
292+ """Two-group split succeeds for a supported explicit request."""
293+ split = _find_supported_split (sm_resource , n_groups = 2 )
294+ if split is None :
295+ pytest .skip ("Device does not expose a valid 2-group split without backfill" )
296+ count , groups , rem = split
252297
253298 assert len (groups ) == 2
254299 assert groups [0 ].sm_count >= count
@@ -257,19 +302,16 @@ def test_two_groups(self, sm_resource):
257302 assert total <= sm_resource .sm_count
258303
259304 def test_two_groups_backfill (self , sm_resource ):
260- """Two-group split with backfill allows larger partitions."""
261- align = sm_resource .coscheduled_alignment
262- if align == 0 :
263- align = sm_resource .min_partition_size
264- half = (sm_resource .sm_count // 2 // align ) * align
265- if half < sm_resource .min_partition_size :
266- pytest .skip ("Not enough SMs for a 2-group backfill split" )
267-
268- groups , rem = sm_resource .split (SMResourceOptions (count = (half , half ), backfill = True ))
305+ """Backfill unlocks a 2-group split size that default placement rejects."""
306+ split = _find_backfill_only_two_group_split (sm_resource )
307+ if split is None :
308+ pytest .skip ("Device does not expose a backfill-only 2-group split" )
309+ requested , groups , rem = split
269310
270311 assert len (groups ) == 2
271- assert groups [0 ].sm_count >= half
272- assert groups [1 ].sm_count >= half
312+ assert groups [0 ].sm_count >= requested
313+ assert groups [1 ].sm_count >= requested
314+ assert groups [0 ].sm_count + groups [1 ].sm_count + rem .sm_count <= sm_resource .sm_count
273315
274316 def test_dry_run_matches_real (self , sm_resource ):
275317 """Dry-run reports the same SM counts as a real split."""
@@ -360,11 +402,10 @@ def test_green_ctx_sm_resources(self, green_ctx, sm_resource):
360402
361403 def test_green_ctx_resources_reflect_partition (self , init_cuda , sm_resource ):
362404 """Two green contexts should have disjoint SM partitions."""
363- count = _safe_two_group_count (sm_resource )
364- if count is None :
365- pytest .skip ("Not enough SMs for a 2-group split" )
366-
367- groups , _ = sm_resource .split (SMResourceOptions (count = (count , count )))
405+ split = _find_any_two_group_split (sm_resource )
406+ if split is None :
407+ pytest .skip ("Device does not expose a valid 2-group split" )
408+ _ , groups , _ = split
368409
369410 ctx_a = ctx_b = None
370411 try :
@@ -433,11 +474,10 @@ def test_launch_and_verify(self, init_cuda, green_ctx, fill_kernel):
433474 def test_two_green_contexts_independent (self , init_cuda , sm_resource , fill_kernel ):
434475 """Two SM groups -> two green contexts -> two independent kernels."""
435476 dev = init_cuda
436- count = _safe_two_group_count (sm_resource )
437- if count is None :
438- pytest .skip ("Not enough SMs for a 2-group split" )
439-
440- groups , _ = sm_resource .split (SMResourceOptions (count = (count , count )))
477+ split = _find_any_two_group_split (sm_resource )
478+ if split is None :
479+ pytest .skip ("Device does not expose a valid 2-group split" )
480+ _ , groups , _ = split
441481 assert len (groups ) == 2
442482
443483 ctx_a = ctx_b = None
0 commit comments