@@ -49,6 +49,7 @@ class _OptixState:
4949 """
5050
5151 def __init__ (self ):
52+ self .device_id = None # CUDA device ID used for this context
5253 self .context = None
5354 self .module = None
5455 self .pipeline = None
@@ -83,6 +84,9 @@ def __init__(self):
8384
8485 def cleanup (self ):
8586 """Release all OptiX and CUDA resources."""
87+ # Reset device tracking
88+ self .device_id = None
89+
8690 # Free device buffers
8791 self .d_params = None
8892 self .d_rays = None
@@ -114,6 +118,10 @@ def cleanup(self):
114118
115119 self .initialized = False
116120
121+ def reset_device (self ):
122+ """Reset device tracking (called during cleanup)."""
123+ self .device_id = None
124+
117125
118126_state = _OptixState ()
119127
@@ -125,6 +133,106 @@ def _cleanup_at_exit():
125133 _state .cleanup ()
126134
127135
136+ # -----------------------------------------------------------------------------
137+ # Device utilities
138+ # -----------------------------------------------------------------------------
139+
140+ def get_device_count () -> int :
141+ """
142+ Get the number of available CUDA devices.
143+
144+ Returns:
145+ Number of CUDA-capable GPUs available.
146+
147+ Example:
148+ >>> import rtxpy
149+ >>> rtxpy.get_device_count()
150+ 2
151+ """
152+ return cupy .cuda .runtime .getDeviceCount ()
153+
154+
155+ def get_device_properties (device : int = 0 ) -> dict :
156+ """
157+ Get properties of a CUDA device.
158+
159+ Args:
160+ device: Device ID (0, 1, 2, ...). Defaults to device 0.
161+
162+ Returns:
163+ Dictionary containing device properties including:
164+ - name: Device name (e.g., "NVIDIA GeForce RTX 3090")
165+ - compute_capability: Tuple of (major, minor) compute capability
166+ - total_memory: Total device memory in bytes
167+ - multiprocessor_count: Number of streaming multiprocessors
168+
169+ Raises:
170+ ValueError: If device ID is invalid.
171+
172+ Example:
173+ >>> import rtxpy
174+ >>> props = rtxpy.get_device_properties(0)
175+ >>> print(props['name'])
176+ NVIDIA GeForce RTX 3090
177+ """
178+ device_count = cupy .cuda .runtime .getDeviceCount ()
179+ if device < 0 or device >= device_count :
180+ raise ValueError (
181+ f"Invalid device ID { device } . "
182+ f"Available devices: 0-{ device_count - 1 } "
183+ )
184+
185+ with cupy .cuda .Device (device ):
186+ props = cupy .cuda .runtime .getDeviceProperties (device )
187+
188+ return {
189+ 'name' : props ['name' ].decode ('utf-8' ) if isinstance (props ['name' ], bytes ) else props ['name' ],
190+ 'compute_capability' : (props ['major' ], props ['minor' ]),
191+ 'total_memory' : props ['totalGlobalMem' ],
192+ 'multiprocessor_count' : props ['multiProcessorCount' ],
193+ }
194+
195+
196+ def list_devices () -> list :
197+ """
198+ List all available CUDA devices with their properties.
199+
200+ Returns:
201+ List of dictionaries, each containing device properties.
202+ Each dict includes 'id' (device index) plus all properties
203+ from get_device_properties().
204+
205+ Example:
206+ >>> import rtxpy
207+ >>> for dev in rtxpy.list_devices():
208+ ... print(f"GPU {dev['id']}: {dev['name']}")
209+ GPU 0: NVIDIA GeForce RTX 3090
210+ GPU 1: NVIDIA GeForce RTX 2080
211+ """
212+ devices = []
213+ for i in range (get_device_count ()):
214+ props = get_device_properties (i )
215+ props ['id' ] = i
216+ devices .append (props )
217+ return devices
218+
219+
220+ def get_current_device () -> Optional [int ]:
221+ """
222+ Get the CUDA device ID that RTX is currently using.
223+
224+ Returns:
225+ Device ID if RTX has been initialized, None otherwise.
226+
227+ Example:
228+ >>> import rtxpy
229+ >>> rtx = rtxpy.RTX(device=1)
230+ >>> rtxpy.get_current_device()
231+ 1
232+ """
233+ return _state .device_id if _state .initialized else None
234+
235+
128236# -----------------------------------------------------------------------------
129237# PTX loading
130238# -----------------------------------------------------------------------------
@@ -157,13 +265,43 @@ def _log_callback(level, tag, message):
157265 print (f"[OPTIX][{ level } ][{ tag } ]: { message } " )
158266
159267
160- def _init_optix ():
161- """Initialize OptiX context, module, pipeline, and SBT."""
268+ def _init_optix (device : Optional [int ] = None ):
269+ """
270+ Initialize OptiX context, module, pipeline, and SBT.
271+
272+ Args:
273+ device: CUDA device ID to use. If None, uses the current CuPy device.
274+ If already initialized, this parameter is ignored (a warning
275+ would be appropriate if it differs from the active device).
276+ """
162277 global _state
163278
164279 if _state .initialized :
280+ # Already initialized - check if user requested a different device
281+ if device is not None and _state .device_id != device :
282+ import warnings
283+ warnings .warn (
284+ f"RTX already initialized on device { _state .device_id } . "
285+ f"Ignoring request for device { device } . "
286+ "Create a new Python process to use a different device." ,
287+ RuntimeWarning
288+ )
165289 return
166290
291+ # Select the CUDA device if specified
292+ if device is not None :
293+ device_count = cupy .cuda .runtime .getDeviceCount ()
294+ if device < 0 or device >= device_count :
295+ raise ValueError (
296+ f"Invalid device ID { device } . "
297+ f"Available devices: 0-{ device_count - 1 } "
298+ )
299+ cupy .cuda .Device (device ).use ()
300+ _state .device_id = device
301+ else :
302+ # Use current device
303+ _state .device_id = cupy .cuda .Device ().id
304+
167305 # Create OptiX device context (uses cupy's CUDA context)
168306 _state .context = optix .deviceContextCreate (
169307 cupy .cuda .get_current_stream ().ptr ,
@@ -736,11 +874,34 @@ class RTX:
736874
737875 This class provides GPU-accelerated ray-triangle intersection using
738876 NVIDIA's OptiX ray tracing engine.
877+
878+ Args:
879+ device: CUDA device ID to use (0, 1, 2, ...). If None (default),
880+ uses the currently active CuPy device. Use get_device_count()
881+ to see available devices.
882+
883+ Example:
884+ # Use default device (device 0 or current CuPy device)
885+ rtx = RTX()
886+
887+ # Use specific GPU
888+ rtx = RTX(device=1)
889+
890+ Note:
891+ The RTX context is a singleton - all RTX instances share the same
892+ underlying OptiX context. The device can only be set on first
893+ initialization. Subsequent RTX() calls with a different device
894+ will emit a warning.
739895 """
740896
741- def __init__ (self ):
742- """Initialize the RTX context."""
743- _init_optix ()
897+ def __init__ (self , device : Optional [int ] = None ):
898+ """
899+ Initialize the RTX context.
900+
901+ Args:
902+ device: CUDA device ID to use. If None, uses the current device.
903+ """
904+ _init_optix (device )
744905
745906 def build (self , hashValue : int , vertexBuffer , indexBuffer ) -> int :
746907 """
@@ -756,6 +917,16 @@ def build(self, hashValue: int, vertexBuffer, indexBuffer) -> int:
756917 """
757918 return _build_accel (hashValue , vertexBuffer , indexBuffer )
758919
920+ @property
921+ def device (self ) -> Optional [int ]:
922+ """
923+ The CUDA device ID this RTX instance is using.
924+
925+ Returns:
926+ Device ID (0, 1, 2, ...) or None if not initialized.
927+ """
928+ return _state .device_id
929+
759930 def getHash (self ) -> int :
760931 """
761932 Get the hash of the current acceleration structure.
0 commit comments