Skip to content

Commit 34ecade

Browse files
committed
test_green_context.py: adapt SM split tests to topology
Probe for supported explicit SM split sizes instead of assuming Hopper+ devices always expose 8-SM partitions, so Thor-like topologies pass without masking real driver errors.
1 parent 326d522 commit 34ecade

1 file changed

Lines changed: 81 additions & 41 deletions

File tree

cuda_core/tests/test_green_context.py

Lines changed: 81 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,58 @@ def fill_kernel(init_cuda):
8282
return mod.get_kernel("fill")
8383

8484

85-
def _safe_two_group_count(sm):
86-
"""Return a safe per-group SM count for a 2-group split.
85+
def _is_invalid_resource_configuration(exc):
    """Return True when *exc* reports an invalid resource configuration.

    The check is purely textual: it looks for the
    ``CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION`` error name inside
    ``str(exc)``.
    """
    message = str(exc)
    return "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in message
8787

88-
Uses min_partition_size which is always a valid split size regardless
89-
of hardware topology. Returns None if the device doesn't have enough SMs.
90-
"""
91-
min_size = sm.min_partition_size
92-
if sm.sm_count < 2 * min_size:
93-
return None
94-
return min_size
88+
89+
def _iter_requested_sm_counts(sm, n_groups=1, *, descending=False):
    """Return the even per-group SM counts worth probing on this device.

    Candidates start at ``max(2, sm.min_partition_size)`` rounded up to the
    next even number and run up to ``sm.sm_count // n_groups`` inclusive, in
    steps of 2.

    Args:
        sm: Object exposing integer ``min_partition_size`` and ``sm_count``.
        n_groups: Number of equally sized groups the SMs are divided into.
        descending: When true, order candidates from largest to smallest.

    Returns:
        A ``range`` of candidate counts; empty when no candidate fits.

    Raises:
        ValueError: If ``n_groups`` is less than 1.
    """
    if n_groups < 1:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError(f"n_groups must be at least 1, got {n_groups}")
    start = max(2, sm.min_partition_size)
    start += start % 2  # round an odd lower bound up to the next even value
    stop = sm.sm_count // n_groups
    counts = range(start, stop + 1, 2)
    # Slicing keeps the result a re-iterable range in both directions.
    return counts[::-1] if descending else counts
97+
98+
99+
def _try_sm_split(sm, *, count, backfill=False):
    """Attempt ``sm.split`` for an explicit request.

    Returns whatever ``sm.split`` returns, or ``None`` when the attempt
    raises a ``CUDAError`` that ``_is_invalid_resource_configuration``
    recognizes. Any other ``CUDAError`` is re-raised unchanged.
    """
    try:
        # The options object is built inside the try so that it is covered
        # by the same error handling as the split call itself.
        result = sm.split(SMResourceOptions(count=count, backfill=backfill))
    except CUDAError as err:
        if not _is_invalid_resource_configuration(err):
            raise
        return None
    return result
106+
107+
108+
def _find_supported_split(sm, *, n_groups=1, backfill=False, descending=False):
    """Return the first explicit split request this device accepts, if any.

    Walks the candidates from ``_iter_requested_sm_counts`` in the requested
    order. A single-group request passes the bare count; a multi-group
    request passes a tuple repeating the count ``n_groups`` times.

    Returns:
        ``(count, groups, rem)`` for the first candidate that
        ``_try_sm_split`` accepts, or ``None`` when every candidate is
        rejected.
    """
    single_group = n_groups == 1
    for per_group in _iter_requested_sm_counts(sm, n_groups=n_groups, descending=descending):
        requested = per_group if single_group else (per_group,) * n_groups
        outcome = _try_sm_split(sm, count=requested, backfill=backfill)
        if outcome is None:
            continue
        groups, remainder = outcome
        return per_group, groups, remainder
    return None
117+
118+
119+
def _find_any_two_group_split(sm):
    """Return a supported 2-group split, preferring one without backfill.

    Tries ``_find_supported_split`` with default placement first and only
    retries with ``backfill=True`` when that search finds nothing. Returns
    ``None`` when both searches come up empty.
    """
    without_backfill = _find_supported_split(sm, n_groups=2)
    if without_backfill is None:
        return _find_supported_split(sm, n_groups=2, backfill=True)
    return without_backfill
124+
125+
126+
def _find_backfill_only_two_group_split(sm):
    """Return a 2-group split size that only succeeds with backfill, if any.

    Candidates are probed from largest to smallest. A size qualifies when
    the request without backfill is rejected and the same request with
    ``backfill=True`` is accepted.

    Returns:
        ``(count, groups, rem)`` for the first qualifying size, or ``None``
        when no size qualifies.
    """
    for per_group in _iter_requested_sm_counts(sm, n_groups=2, descending=True):
        pair = (per_group, per_group)
        # Sizes that already work without backfill are not backfill-only.
        if _try_sm_split(sm, count=pair) is not None:
            continue
        with_backfill = _try_sm_split(sm, count=pair, backfill=True)
        if with_backfill is None:
            continue
        groups, remainder = with_backfill
        return per_group, groups, remainder
    return None
95137

96138

97139
@contextlib.contextmanager
@@ -153,8 +195,10 @@ def test_arch_constraints_pre_hopper(self, init_cuda, sm_resource):
153195
def test_arch_constraints_hopper_plus(self, init_cuda, sm_resource):
154196
if init_cuda.compute_capability < (9, 0):
155197
pytest.skip("Test is for Hopper+ architectures")
156-
assert sm_resource.min_partition_size >= 8
157-
assert sm_resource.coscheduled_alignment >= 8
198+
assert sm_resource.min_partition_size >= 2
199+
assert sm_resource.coscheduled_alignment >= 2
200+
assert sm_resource.min_partition_size % 2 == 0
201+
assert sm_resource.coscheduled_alignment % 2 == 0
158202

159203

160204
# ---------------------------------------------------------------------------
@@ -221,9 +265,11 @@ def test_dry_run_cannot_create_context(self, init_cuda, sm_resource):
221265

222266
class TestSMResourceSplit:
223267
def test_single_group_counts(self, sm_resource):
224-
"""Single-group split: group gets at least requested SMs."""
225-
requested = sm_resource.min_partition_size
226-
groups, rem = sm_resource.split(SMResourceOptions(count=requested))
268+
"""Single-group split: group gets at least a supported requested size."""
269+
split = _find_supported_split(sm_resource)
270+
if split is None:
271+
pytest.skip("Device does not expose a valid explicit single-group split")
272+
requested, groups, rem = split
227273

228274
assert len(groups) == 1
229275
assert groups[0].sm_count >= requested
@@ -243,12 +289,11 @@ def test_discovery_respects_alignment(self, sm_resource):
243289
assert groups[0].sm_count % sm_resource.coscheduled_alignment == 0
244290

245291
def test_two_groups(self, sm_resource):
246-
"""Two-group split with min_partition_size (always topology-safe)."""
247-
count = _safe_two_group_count(sm_resource)
248-
if count is None:
249-
pytest.skip("Not enough SMs for a 2-group split")
250-
251-
groups, rem = sm_resource.split(SMResourceOptions(count=(count, count)))
292+
"""Two-group split succeeds for a supported explicit request."""
293+
split = _find_supported_split(sm_resource, n_groups=2)
294+
if split is None:
295+
pytest.skip("Device does not expose a valid 2-group split without backfill")
296+
count, groups, rem = split
252297

253298
assert len(groups) == 2
254299
assert groups[0].sm_count >= count
@@ -257,19 +302,16 @@ def test_two_groups(self, sm_resource):
257302
assert total <= sm_resource.sm_count
258303

259304
def test_two_groups_backfill(self, sm_resource):
260-
"""Two-group split with backfill allows larger partitions."""
261-
align = sm_resource.coscheduled_alignment
262-
if align == 0:
263-
align = sm_resource.min_partition_size
264-
half = (sm_resource.sm_count // 2 // align) * align
265-
if half < sm_resource.min_partition_size:
266-
pytest.skip("Not enough SMs for a 2-group backfill split")
267-
268-
groups, rem = sm_resource.split(SMResourceOptions(count=(half, half), backfill=True))
305+
"""Backfill unlocks a 2-group split size that default placement rejects."""
306+
split = _find_backfill_only_two_group_split(sm_resource)
307+
if split is None:
308+
pytest.skip("Device does not expose a backfill-only 2-group split")
309+
requested, groups, rem = split
269310

270311
assert len(groups) == 2
271-
assert groups[0].sm_count >= half
272-
assert groups[1].sm_count >= half
312+
assert groups[0].sm_count >= requested
313+
assert groups[1].sm_count >= requested
314+
assert groups[0].sm_count + groups[1].sm_count + rem.sm_count <= sm_resource.sm_count
273315

274316
def test_dry_run_matches_real(self, sm_resource):
275317
"""Dry-run reports the same SM counts as a real split."""
@@ -360,11 +402,10 @@ def test_green_ctx_sm_resources(self, green_ctx, sm_resource):
360402

361403
def test_green_ctx_resources_reflect_partition(self, init_cuda, sm_resource):
362404
"""Two green contexts should have disjoint SM partitions."""
363-
count = _safe_two_group_count(sm_resource)
364-
if count is None:
365-
pytest.skip("Not enough SMs for a 2-group split")
366-
367-
groups, _ = sm_resource.split(SMResourceOptions(count=(count, count)))
405+
split = _find_any_two_group_split(sm_resource)
406+
if split is None:
407+
pytest.skip("Device does not expose a valid 2-group split")
408+
_, groups, _ = split
368409

369410
ctx_a = ctx_b = None
370411
try:
@@ -433,11 +474,10 @@ def test_launch_and_verify(self, init_cuda, green_ctx, fill_kernel):
433474
def test_two_green_contexts_independent(self, init_cuda, sm_resource, fill_kernel):
434475
"""Two SM groups -> two green contexts -> two independent kernels."""
435476
dev = init_cuda
436-
count = _safe_two_group_count(sm_resource)
437-
if count is None:
438-
pytest.skip("Not enough SMs for a 2-group split")
439-
440-
groups, _ = sm_resource.split(SMResourceOptions(count=(count, count)))
477+
split = _find_any_two_group_split(sm_resource)
478+
if split is None:
479+
pytest.skip("Device does not expose a valid 2-group split")
480+
_, groups, _ = split
441481
assert len(groups) == 2
442482

443483
ctx_a = ctx_b = None

0 commit comments

Comments
 (0)