From fb6b7f7a4b77799a4f0b89a3152267ac752d78a2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 9 Jun 2026 20:36:24 -0700 Subject: [PATCH] Fix ROCm GPU arch detection: prefer torch device properties On Windows, get_rocm_gpu_arch() probed hipinfo.exe via PATH only. In practice hipInfo.exe is rarely on PATH: hosts without the HIP SDK do not have it there, and AMD's PyTorch wheels ship hipInfo.exe into the environment's Scripts directory, which is only on PATH when the venv is activated. The probe then raises FileNotFoundError, every import of bitsandbytes logs an ERROR + WARNING, and ROCM_GPU_ARCH silently degrades to unknown. Read torch.cuda.get_device_properties(0).gcnArchName first (works on Linux and Windows, no subprocess); keep the rocminfo / hipInfo.exe parsing as a fallback, additionally trying hipInfo.exe next to python.exe on Windows before giving up. Verified on gfx1151 (Strix Halo, Windows 11, torch 2.11.0+rocm7.13.0): previously unknown + ERROR; now gfx1151 via both the torch path and the forced subprocess fallback. Co-Authored-By: Claude Fable 5 --- bitsandbytes/cuda_specs.py | 58 ++++++++++++++++++++---------- tests/test_cuda_setup_evaluator.py | 23 +++++++++++- 2 files changed, 61 insertions(+), 20 deletions(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 25ce3cd1e..dfa571d3c 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,9 +1,11 @@ import dataclasses from functools import lru_cache import logging +import os import platform import re import subprocess +import sys from typing import Optional import torch @@ -82,30 +84,48 @@ def get_cuda_specs() -> Optional[CUDASpecs]: def get_rocm_gpu_arch() -> str: """Get ROCm GPU architecture.""" logger = logging.getLogger(__name__) - try: - if torch.version.hip: - # On Windows, use hipinfo.exe; on Linux, use rocminfo - if platform.system() == "Windows": - cmd = ["hipinfo.exe"] - arch_pattern = r"gcnArchName:\s+gfx([a-zA-Z\d]+)" - else: - cmd = ["rocminfo"] - arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)" + if not torch.version.hip: + return "unknown" + + # Prefer the architecture torch already knows; this needs no subprocess. + if torch.cuda.is_available(): + try: + # gcnArchName may include feature flags, e.g. "gfx90a:sramecc+:xnack-". + return torch.cuda.get_device_properties(0).gcnArchName.split(":")[0] + except Exception as e: + logger.debug(f"Could not get ROCm GPU architecture from torch: {e}") + + # Fall back to parsing tool output. On Windows, use hipInfo.exe; on Linux, use rocminfo. + if platform.system() == "Windows": + # hipInfo.exe is usually not on PATH: the HIP SDK does not add its bin directory, + # and AMD's PyTorch wheels for Windows ship hipInfo.exe next to python.exe instead. + cmds = [ + ["hipinfo.exe"], + [os.path.join(os.path.dirname(sys.executable), "hipInfo.exe")], + ] + arch_pattern = r"gcnArchName:\s+gfx([a-zA-Z\d]+)" + else: + cmds = [["rocminfo"]] + arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)" + + last_error: Optional[Exception] = None + for cmd in cmds: + try: result = subprocess.run(cmd, capture_output=True, text=True) - match = re.search(arch_pattern, result.stdout) - if match: - return "gfx" + match.group(1) - else: - return "unknown" - else: - return "unknown" - except Exception as e: - logger.error(f"Could not detect ROCm GPU architecture: {e}") + except Exception as e: + last_error = e + continue + match = re.search(arch_pattern, result.stdout) + if match: + return "gfx" + match.group(1) + + if last_error is not None: + logger.error(f"Could not detect ROCm GPU architecture: {last_error}") if torch.cuda.is_available(): logger.warning( """ ROCm GPU architecture detection failed despite ROCm being available. """, ) - return "unknown" + return "unknown" diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index a42b026f7..59f8f24c9 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,7 +1,10 @@ +from types import SimpleNamespace + import pytest +import torch from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path -from bitsandbytes.cuda_specs import CUDASpecs +from bitsandbytes.cuda_specs import CUDASpecs, get_rocm_gpu_arch @pytest.fixture @@ -96,3 +99,21 @@ def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spe monkeypatch.setenv("BNB_CUDA_VERSION", "110") with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*not a CUDA build"): get_cuda_bnb_library_path(rocm70_spec) + + +def test_get_rocm_gpu_arch_from_torch(monkeypatch): + """With a visible GPU, the architecture comes from torch device properties, without a subprocess.""" + monkeypatch.setattr(torch.version, "hip", "7.0.0") + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr( + torch.cuda, + "get_device_properties", + lambda device: SimpleNamespace(gcnArchName="gfx90a:sramecc+:xnack-"), + ) + assert get_rocm_gpu_arch() == "gfx90a" + + +def test_get_rocm_gpu_arch_non_rocm(monkeypatch): + """On non-ROCm builds, no detection is attempted.""" + monkeypatch.setattr(torch.version, "hip", None) + assert get_rocm_gpu_arch() == "unknown"