Skip to content

Commit cb7d340

Browse files
committed
Improve docs and exception-handling
1 parent 5f15b4c commit cb7d340

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

cuda_core/cuda/core/system/_device.pyx

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,15 @@ cdef class Device:
376376
-------
377377
cuda.core.Device
378378
The corresponding CUDA device.
379+
380+
Raises
381+
------
382+
RuntimeError
383+
No corresponding CUDA device is found for this NVML device.
384+
385+
For example, on a MIG system, the physical GPU will not have an
386+
available CUDA device, since it can not be used directly, even
387+
though it can be enumerated from NVML.
379388
"""
380389
from cuda.core import Device as CudaDevice
381390

@@ -890,8 +899,16 @@ cdef class Device:
890899
def pci_info(self) -> PciInfo:
891900
"""
892901
:obj:`~_device.PciInfo` object with the PCI attributes of this device.
902+
903+
Non-physical devices, such as MIG devices, may not have PCI attributes.
904+
In that case, this property will raise a `RuntimeError`.
893905
"""
894-
return PciInfo(nvml.device_get_pci_info_ext(self._handle), self._handle)
906+
try:
907+
pci_info = nvml.device_get_pci_info_ext(self._handle)
908+
except nvml.InvalidArgumentError:
909+
raise RuntimeError("This device does not have PCI attributes") from None
910+
else:
911+
return PciInfo(pci_info, self._handle)
895912
896913
##########################################################################
897914
# PERFORMANCE

cuda_core/tests/test_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_to_system_device(deinit_cuda):
4848
# CUDA only returns a 2-byte PCI bus ID domain, whereas NVML returns a
4949
# 4-byte domain
5050
# MIG devices don't have pci_info, so skip the bus ID check if it's missing
51-
with contextlib.suppress(_system.InvalidArgumentError):
51+
with contextlib.suppress(RuntimeError):
5252
assert device.pci_bus_id == system_device.pci_info.bus_id[4:]
5353

5454

0 commit comments

Comments
 (0)