Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 39 additions & 19 deletions bitsandbytes/cuda_specs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import dataclasses
from functools import lru_cache
import logging
import os
import platform
import re
import subprocess
import sys
from typing import Optional

import torch
Expand Down Expand Up @@ -82,30 +84,48 @@ def get_cuda_specs() -> Optional[CUDASpecs]:
def get_rocm_gpu_arch() -> str:
"""Get ROCm GPU architecture."""
logger = logging.getLogger(__name__)
try:
if torch.version.hip:
# On Windows, use hipinfo.exe; on Linux, use rocminfo
if platform.system() == "Windows":
cmd = ["hipinfo.exe"]
arch_pattern = r"gcnArchName:\s+gfx([a-zA-Z\d]+)"
else:
cmd = ["rocminfo"]
arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)"

if not torch.version.hip:
return "unknown"

# Prefer the architecture torch already knows; this needs no subprocess.
if torch.cuda.is_available():
try:
# gcnArchName may include feature flags, e.g. "gfx90a:sramecc+:xnack-".
return torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
except Exception as e:
logger.debug(f"Could not get ROCm GPU architecture from torch: {e}")

# Fall back to parsing tool output. On Windows, use hipInfo.exe; on Linux, use rocminfo.
if platform.system() == "Windows":
# hipInfo.exe is usually not on PATH: the HIP SDK does not add its bin directory,
# and AMD's PyTorch wheels for Windows ship hipInfo.exe next to python.exe instead.
cmds = [
["hipinfo.exe"],
[os.path.join(os.path.dirname(sys.executable), "hipInfo.exe")],
]
arch_pattern = r"gcnArchName:\s+gfx([a-zA-Z\d]+)"
else:
cmds = [["rocminfo"]]
arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)"

last_error: Optional[Exception] = None
for cmd in cmds:
try:
result = subprocess.run(cmd, capture_output=True, text=True)
match = re.search(arch_pattern, result.stdout)
if match:
return "gfx" + match.group(1)
else:
return "unknown"
else:
return "unknown"
except Exception as e:
logger.error(f"Could not detect ROCm GPU architecture: {e}")
except Exception as e:
last_error = e
continue
match = re.search(arch_pattern, result.stdout)
if match:
return "gfx" + match.group(1)

if last_error is not None:
logger.error(f"Could not detect ROCm GPU architecture: {last_error}")
if torch.cuda.is_available():
logger.warning(
"""
ROCm GPU architecture detection failed despite ROCm being available.
""",
)
return "unknown"
return "unknown"
23 changes: 22 additions & 1 deletion tests/test_cuda_setup_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from types import SimpleNamespace

import pytest
import torch

from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path
from bitsandbytes.cuda_specs import CUDASpecs
from bitsandbytes.cuda_specs import CUDASpecs, get_rocm_gpu_arch


@pytest.fixture
Expand Down Expand Up @@ -96,3 +99,21 @@ def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spe
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*not a CUDA build"):
get_cuda_bnb_library_path(rocm70_spec)


def test_get_rocm_gpu_arch_from_torch(monkeypatch):
"""With a visible GPU, the architecture comes from torch device properties, without a subprocess."""
monkeypatch.setattr(torch.version, "hip", "7.0.0")
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
monkeypatch.setattr(
torch.cuda,
"get_device_properties",
lambda device: SimpleNamespace(gcnArchName="gfx90a:sramecc+:xnack-"),
)
assert get_rocm_gpu_arch() == "gfx90a"


def test_get_rocm_gpu_arch_non_rocm(monkeypatch):
"""On non-ROCm builds, no detection is attempted."""
monkeypatch.setattr(torch.version, "hip", None)
assert get_rocm_gpu_arch() == "unknown"