diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 25ce3cd1e..dfa571d3c 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,9 +1,11 @@ import dataclasses from functools import lru_cache import logging +import os import platform import re import subprocess +import sys from typing import Optional import torch @@ -82,30 +84,48 @@ def get_cuda_specs() -> Optional[CUDASpecs]: def get_rocm_gpu_arch() -> str: """Get ROCm GPU architecture.""" logger = logging.getLogger(__name__) - try: - if torch.version.hip: - # On Windows, use hipinfo.exe; on Linux, use rocminfo - if platform.system() == "Windows": - cmd = ["hipinfo.exe"] - arch_pattern = r"gcnArchName:\s+gfx([a-zA-Z\d]+)" - else: - cmd = ["rocminfo"] - arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)" + if not torch.version.hip: + return "unknown" + + # Prefer the architecture torch already knows; this needs no subprocess. + if torch.cuda.is_available(): + try: + # gcnArchName may include feature flags, e.g. "gfx90a:sramecc+:xnack-". + return torch.cuda.get_device_properties(0).gcnArchName.split(":")[0] + except Exception as e: + logger.debug(f"Could not get ROCm GPU architecture from torch: {e}") + + # Fall back to parsing tool output. On Windows, use hipInfo.exe; on Linux, use rocminfo. + if platform.system() == "Windows": + # hipInfo.exe is usually not on PATH: the HIP SDK does not add its bin directory, + # and AMD's PyTorch wheels for Windows ship hipInfo.exe next to python.exe instead. + cmds = [ + ["hipinfo.exe"], + [os.path.join(os.path.dirname(sys.executable), "hipInfo.exe")], + ] + arch_pattern = r"gcnArchName:\s+gfx([a-zA-Z\d]+)" + else: + cmds = [["rocminfo"]] + arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)" + + last_error: Optional[Exception] = None + for cmd in cmds: + try: result = subprocess.run(cmd, capture_output=True, text=True) - match = re.search(arch_pattern, result.stdout) - if match: - return "gfx" + match.group(1) - else: - return "unknown" - else: - return "unknown" - except Exception as e: - logger.error(f"Could not detect ROCm GPU architecture: {e}") + except Exception as e: + last_error = e + continue + match = re.search(arch_pattern, result.stdout) + if match: + return "gfx" + match.group(1) + + if last_error is not None: + logger.error(f"Could not detect ROCm GPU architecture: {last_error}") if torch.cuda.is_available(): logger.warning( """ ROCm GPU architecture detection failed despite ROCm being available. """, ) - return "unknown" + return "unknown" diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index a42b026f7..59f8f24c9 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,7 +1,10 @@ +from types import SimpleNamespace + import pytest +import torch from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path -from bitsandbytes.cuda_specs import CUDASpecs +from bitsandbytes.cuda_specs import CUDASpecs, get_rocm_gpu_arch @pytest.fixture @@ -96,3 +99,21 @@ def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spe monkeypatch.setenv("BNB_CUDA_VERSION", "110") with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*not a CUDA build"): get_cuda_bnb_library_path(rocm70_spec) + + +def test_get_rocm_gpu_arch_from_torch(monkeypatch): + """With a visible GPU, the architecture comes from torch device properties, without a subprocess.""" + monkeypatch.setattr(torch.version, "hip", "7.0.0") + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr( + torch.cuda, + "get_device_properties", + lambda device: SimpleNamespace(gcnArchName="gfx90a:sramecc+:xnack-"), + ) + assert get_rocm_gpu_arch() == "gfx90a" + + +def test_get_rocm_gpu_arch_non_rocm(monkeypatch): + """On non-ROCm builds, no detection is attempted.""" + monkeypatch.setattr(torch.version, "hip", None) + assert get_rocm_gpu_arch() == "unknown"