|
10 | 10 |
|
11 | 11 | CUDA_BINDINGS_NVML_IS_COMPATIBLE: bool |
12 | 12 |
|
| 13 | + |
| 14 | +# POSIX per-thread locale APIs. We use these (rather than setlocale(3)) |
| 15 | +# so the WSL workaround in get_process_name() doesn't perturb the locale |
| 16 | +# observed by other threads. locale_t is an opaque pointer in glibc. |
| 17 | +cdef extern from "locale.h" nogil: |
| 18 | + ctypedef void *locale_t |
| 19 | + int LC_ALL_MASK |
| 20 | + locale_t LC_GLOBAL_LOCALE |
| 21 | + locale_t newlocale(int category_mask, const char *locale, locale_t base) |
| 22 | + locale_t uselocale(locale_t newloc) |
| 23 | + void freelocale(locale_t locobj) |
| 24 | + |
| 25 | + |
| 26 | +cdef bint _detect_wsl(): |
| 27 | + try: |
| 28 | + with open("/proc/sys/kernel/osrelease") as f: |
| 29 | + data = f.read().lower() |
| 30 | + except OSError: |
| 31 | + return False |
| 32 | + return "microsoft" in data or "wsl" in data |
| 33 | + |
| 34 | + |
| 35 | +cdef bint _IS_WSL = _detect_wsl() |
| 36 | + |
13 | 37 | try: |
14 | 38 | from cuda.bindings._version import __version_tuple__ as _BINDINGS_VERSION |
15 | 39 | except ImportError: |
@@ -127,8 +151,43 @@ def get_process_name(pid: int) -> str: |
127 | 151 | name: str |
128 | 152 | The process name. |
129 | 153 | """ |
| 154 | + def _get_process_name(pid) -> str: |
| 155 | + # NVML caches process names on a per-PID basis when queried via |
| 156 | + # nvmlSystemGetProcessName, and the cache is populated when enumerating |
| 157 | + # running processes on devices. To ensure the name is cached for the |
| 158 | + # requested PID, we walk all devices and query their running processes. |
| 159 | + for i in range(nvml.device_get_count_v2()): |
| 160 | + dev_h = nvml.device_get_handle_by_index_v2(i) |
| 161 | + nvml.device_get_compute_running_processes_v3(dev_h) |
| 162 | + return nvml.system_get_process_name(pid) |
| 163 | + |
| 164 | + cdef locale_t c_locale |
| 165 | + cdef locale_t prev_locale |
| 166 | + |
130 | 167 | initialize() |
131 | | - return nvml.system_get_process_name(pid) |
| 168 | + if not _IS_WSL: |
| 169 | + return _get_process_name(pid) |
| 170 | + |
| 171 | + # WSL workaround: nvmlSystemGetProcessName on WSL takes a wide-char |
| 172 | + # conversion path when the process locale is non-"C". That path walks |
| 173 | + # a UTF-16LE source buffer with a 4-byte stride (as if it were UTF-32LE) |
| 174 | + # and emits 5-byte UTF-8 sequences that look like garbage preceding the |
| 175 | + # trailing basename of /proc/<pid>/exe. CPython's startup unconditionally |
| 176 | + # calls setlocale(LC_ALL, ""), so essentially every cuda.core caller hits |
| 177 | + # this. The cached entry for the PID is set the first time NVML resolves |
| 178 | + # it (typically inside nvmlDeviceGetComputeRunningProcesses_v3), so to |
| 179 | + # recover a correct value we re-prime the cache under the "C" locale |
| 180 | + # before reading the name. We use the POSIX per-thread locale APIs so |
| 181 | + # other threads' view of the locale is unaffected. |
| 182 | + c_locale = newlocale(LC_ALL_MASK, b"C", <locale_t>0) |
| 183 | + if c_locale == <locale_t>0: |
| 184 | + raise RuntimeError("Failed to create C locale") |
| 185 | + prev_locale = uselocale(c_locale) |
| 186 | + try: |
| 187 | + return _get_process_name(pid) |
| 188 | + finally: |
| 189 | + uselocale(prev_locale) |
| 190 | + freelocale(c_locale) |
132 | 191 |
|
133 | 192 |
|
134 | 193 | __all__ = [ |
|
0 commit comments