Skip to content

Commit 25ae3aa

Browse files
authored
added device_id as param to init() to support gpu selection (#33)
1 parent e7513ae commit 25ae3aa

2 files changed

Lines changed: 184 additions & 6 deletions

File tree

rtxpy/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
from .rtx import RTX, has_cupy
1+
from .rtx import (
2+
RTX,
3+
has_cupy,
4+
get_device_count,
5+
get_device_properties,
6+
list_devices,
7+
get_current_device,
8+
)
29

310

411
__version__ = "0.0.4"

rtxpy/rtx.py

Lines changed: 176 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class _OptixState:
4949
"""
5050

5151
def __init__(self):
52+
self.device_id = None # CUDA device ID used for this context
5253
self.context = None
5354
self.module = None
5455
self.pipeline = None
@@ -83,6 +84,9 @@ def __init__(self):
8384

8485
def cleanup(self):
8586
"""Release all OptiX and CUDA resources."""
87+
# Reset device tracking
88+
self.device_id = None
89+
8690
# Free device buffers
8791
self.d_params = None
8892
self.d_rays = None
@@ -114,6 +118,10 @@ def cleanup(self):
114118

115119
self.initialized = False
116120

121+
def reset_device(self):
122+
"""Reset device tracking (called during cleanup)."""
123+
self.device_id = None
124+
117125

118126
_state = _OptixState()
119127

@@ -125,6 +133,106 @@ def _cleanup_at_exit():
125133
_state.cleanup()
126134

127135

136+
# -----------------------------------------------------------------------------
137+
# Device utilities
138+
# -----------------------------------------------------------------------------
139+
140+
def get_device_count() -> int:
141+
"""
142+
Get the number of available CUDA devices.
143+
144+
Returns:
145+
Number of CUDA-capable GPUs available.
146+
147+
Example:
148+
>>> import rtxpy
149+
>>> rtxpy.get_device_count()
150+
2
151+
"""
152+
return cupy.cuda.runtime.getDeviceCount()
153+
154+
155+
def get_device_properties(device: int = 0) -> dict:
156+
"""
157+
Get properties of a CUDA device.
158+
159+
Args:
160+
device: Device ID (0, 1, 2, ...). Defaults to device 0.
161+
162+
Returns:
163+
Dictionary containing device properties including:
164+
- name: Device name (e.g., "NVIDIA GeForce RTX 3090")
165+
- compute_capability: Tuple of (major, minor) compute capability
166+
- total_memory: Total device memory in bytes
167+
- multiprocessor_count: Number of streaming multiprocessors
168+
169+
Raises:
170+
ValueError: If device ID is invalid.
171+
172+
Example:
173+
>>> import rtxpy
174+
>>> props = rtxpy.get_device_properties(0)
175+
>>> print(props['name'])
176+
NVIDIA GeForce RTX 3090
177+
"""
178+
device_count = cupy.cuda.runtime.getDeviceCount()
179+
if device < 0 or device >= device_count:
180+
raise ValueError(
181+
f"Invalid device ID {device}. "
182+
f"Available devices: 0-{device_count - 1}"
183+
)
184+
185+
with cupy.cuda.Device(device):
186+
props = cupy.cuda.runtime.getDeviceProperties(device)
187+
188+
return {
189+
'name': props['name'].decode('utf-8') if isinstance(props['name'], bytes) else props['name'],
190+
'compute_capability': (props['major'], props['minor']),
191+
'total_memory': props['totalGlobalMem'],
192+
'multiprocessor_count': props['multiProcessorCount'],
193+
}
194+
195+
196+
def list_devices() -> list:
197+
"""
198+
List all available CUDA devices with their properties.
199+
200+
Returns:
201+
List of dictionaries, each containing device properties.
202+
Each dict includes 'id' (device index) plus all properties
203+
from get_device_properties().
204+
205+
Example:
206+
>>> import rtxpy
207+
>>> for dev in rtxpy.list_devices():
208+
... print(f"GPU {dev['id']}: {dev['name']}")
209+
GPU 0: NVIDIA GeForce RTX 3090
210+
GPU 1: NVIDIA GeForce RTX 2080
211+
"""
212+
devices = []
213+
for i in range(get_device_count()):
214+
props = get_device_properties(i)
215+
props['id'] = i
216+
devices.append(props)
217+
return devices
218+
219+
220+
def get_current_device() -> Optional[int]:
221+
"""
222+
Get the CUDA device ID that RTX is currently using.
223+
224+
Returns:
225+
Device ID if RTX has been initialized, None otherwise.
226+
227+
Example:
228+
>>> import rtxpy
229+
>>> rtx = rtxpy.RTX(device=1)
230+
>>> rtxpy.get_current_device()
231+
1
232+
"""
233+
return _state.device_id if _state.initialized else None
234+
235+
128236
# -----------------------------------------------------------------------------
129237
# PTX loading
130238
# -----------------------------------------------------------------------------
@@ -157,13 +265,43 @@ def _log_callback(level, tag, message):
157265
print(f"[OPTIX][{level}][{tag}]: {message}")
158266

159267

160-
def _init_optix():
161-
"""Initialize OptiX context, module, pipeline, and SBT."""
268+
def _init_optix(device: Optional[int] = None):
269+
"""
270+
Initialize OptiX context, module, pipeline, and SBT.
271+
272+
Args:
273+
device: CUDA device ID to use. If None, uses the current CuPy device.
274+
If already initialized, this parameter is ignored (a warning
275+
would be appropriate if it differs from the active device).
276+
"""
162277
global _state
163278

164279
if _state.initialized:
280+
# Already initialized - check if user requested a different device
281+
if device is not None and _state.device_id != device:
282+
import warnings
283+
warnings.warn(
284+
f"RTX already initialized on device {_state.device_id}. "
285+
f"Ignoring request for device {device}. "
286+
"Create a new Python process to use a different device.",
287+
RuntimeWarning
288+
)
165289
return
166290

291+
# Select the CUDA device if specified
292+
if device is not None:
293+
device_count = cupy.cuda.runtime.getDeviceCount()
294+
if device < 0 or device >= device_count:
295+
raise ValueError(
296+
f"Invalid device ID {device}. "
297+
f"Available devices: 0-{device_count - 1}"
298+
)
299+
cupy.cuda.Device(device).use()
300+
_state.device_id = device
301+
else:
302+
# Use current device
303+
_state.device_id = cupy.cuda.Device().id
304+
167305
# Create OptiX device context (uses cupy's CUDA context)
168306
_state.context = optix.deviceContextCreate(
169307
cupy.cuda.get_current_stream().ptr,
@@ -736,11 +874,34 @@ class RTX:
736874
737875
This class provides GPU-accelerated ray-triangle intersection using
738876
NVIDIA's OptiX ray tracing engine.
877+
878+
Args:
879+
device: CUDA device ID to use (0, 1, 2, ...). If None (default),
880+
uses the currently active CuPy device. Use get_device_count()
881+
to see available devices.
882+
883+
Example:
884+
# Use default device (device 0 or current CuPy device)
885+
rtx = RTX()
886+
887+
# Use specific GPU
888+
rtx = RTX(device=1)
889+
890+
Note:
891+
The RTX context is a singleton - all RTX instances share the same
892+
underlying OptiX context. The device can only be set on first
893+
initialization. Subsequent RTX() calls with a different device
894+
will emit a warning.
739895
"""
740896

741-
def __init__(self):
742-
"""Initialize the RTX context."""
743-
_init_optix()
897+
def __init__(self, device: Optional[int] = None):
898+
"""
899+
Initialize the RTX context.
900+
901+
Args:
902+
device: CUDA device ID to use. If None, uses the current device.
903+
"""
904+
_init_optix(device)
744905

745906
def build(self, hashValue: int, vertexBuffer, indexBuffer) -> int:
746907
"""
@@ -756,6 +917,16 @@ def build(self, hashValue: int, vertexBuffer, indexBuffer) -> int:
756917
"""
757918
return _build_accel(hashValue, vertexBuffer, indexBuffer)
758919

920+
@property
921+
def device(self) -> Optional[int]:
922+
"""
923+
The CUDA device ID this RTX instance is using.
924+
925+
Returns:
926+
Device ID (0, 1, 2, ...) or None if not initialized.
927+
"""
928+
return _state.device_id
929+
759930
def getHash(self) -> int:
760931
"""
761932
Get the hash of the current acceleration structure.

0 commit comments

Comments
 (0)