Skip to content

Commit e047570

Browse files
Add cuDLA bindings (#2034)
* Add cuDLA bindings Generated from cudla.h using cybind. Files added: - cycudla.pxd/pyx: Cython layer exposing C header types and functions - cudla.pxd/pyx: lowpp Python layer with POD classes, enums, and wrappers - _internal/cudla.pxd, cudla_linux.pyx, cudla_windows.pyx: dynamic library loading - docs/source/module/cudla.rst: API documentation - tests/cudla/: pytest unit tests for enums, POD types, error handling, API surface, and hardware-gated function tests (verified on L4T/Orin) Build/CI updates: - pyproject.toml: added cudla to cuda-toolkit optional dependencies - .github/actions/fetch_ctk/action.yml: added libcudla to CTK components - docs/source/api.rst: added cudla to toctree * fixup: ruff lint fixes * Add SPDX license headers to cuDLA binding files * fixed SPDX license headers format * Fix fetch_ctk redistrib component resolution. Use redistrib metadata to skip unsupported mini-CTK components and resolve archive paths through a tested helper, including container-safe workspace paths for runtime jobs. * Remove cudla_windows.pyx and hardware-gated tests * Fix cudla.rst: remove unimplemented functions * fixed ruff failure in test_cudla_bindings.py --------- Co-authored-by: Ralf W. Grosse-Kunstleve <rgrossekunst@nvidia.com>
1 parent 41cea4c commit e047570

14 files changed

Lines changed: 3176 additions & 27 deletions

File tree

.github/actions/fetch_ctk/action.yml

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ inputs:
1414
cuda-components:
1515
description: "A list of the CTK components to install as a comma-separated list. e.g. 'cuda_nvcc,cuda_nvrtc,cuda_cudart'"
1616
required: false
17-
default: "cuda_nvcc,cuda_cudart,cuda_crt,libnvvm,cuda_nvrtc,cuda_profiler_api,cuda_cccl,cuda_cupti,libnvjitlink,libcufile,libnvfatbin"
17+
default: "cuda_nvcc,cuda_cudart,cuda_crt,libnvvm,cuda_nvrtc,cuda_profiler_api,cuda_cccl,cuda_cupti,libnvjitlink,libcufile,libnvfatbin,libcudla"
1818
cuda-path:
1919
description: "where the CTK components will be installed to, relative to $PWD"
2020
required: false
@@ -27,24 +27,15 @@ runs:
2727
shell: bash --noprofile --norc -xeuo pipefail {0}
2828
run: |
2929
# Pre-process the component list to ensure hash uniqueness
30+
# Use the runtime workspace mount so this also works inside container jobs.
31+
CTK_REDIST_TOOL="${GITHUB_WORKSPACE}/ci/tools/fetch_ctk_redistrib.py"
3032
CTK_CACHE_COMPONENTS=${{ inputs.cuda-components }}
31-
# Conditionally strip out libnvjitlink for CUDA versions < 12
32-
CUDA_MAJOR_VER="$(cut -d '.' -f 1 <<< ${{ inputs.cuda-version }})"
33-
if [[ "$CUDA_MAJOR_VER" -lt 12 ]]; then
34-
CTK_CACHE_COMPONENTS="${CTK_CACHE_COMPONENTS//libnvjitlink/}"
35-
fi
36-
# Conditionally strip out cuda_crt and libnvvm for CUDA versions < 13
37-
CUDA_MAJOR_VER="$(cut -d '.' -f 1 <<< ${{ inputs.cuda-version }})"
38-
if [[ "$CUDA_MAJOR_VER" -lt 13 ]]; then
39-
CTK_CACHE_COMPONENTS="${CTK_CACHE_COMPONENTS//cuda_crt/}"
40-
CTK_CACHE_COMPONENTS="${CTK_CACHE_COMPONENTS//libnvvm/}"
41-
fi
42-
# Conditionally strip out libcufile since it does not support Windows
43-
if [[ "${{ inputs.host-platform }}" == win-* ]]; then
44-
CTK_CACHE_COMPONENTS="${CTK_CACHE_COMPONENTS//libcufile/}"
45-
fi
46-
# Cleanup stray commas after removing components
47-
CTK_CACHE_COMPONENTS="${CTK_CACHE_COMPONENTS//,,/,}"
33+
CTK_JSON_URL="https://developer.download.nvidia.com/compute/cuda/redist/redistrib_${{ inputs.cuda-version }}.json"
34+
CTK_CACHE_COMPONENTS="$(python "$CTK_REDIST_TOOL" filter-components \
35+
--host-platform "${{ inputs.host-platform }}" \
36+
--cuda-version "${{ inputs.cuda-version }}" \
37+
--components "$CTK_CACHE_COMPONENTS" \
38+
--metadata-url "$CTK_JSON_URL")"
4839
4940
HASH=$(echo -n "${CTK_CACHE_COMPONENTS}" | sha256sum | awk '{print $1}')
5041
echo "CTK_CACHE_KEY=mini-ctk-${{ inputs.cuda-version }}-${{ inputs.host-platform }}-$HASH" >> $GITHUB_ENV
@@ -78,19 +69,17 @@ runs:
7869
mkdir $CACHE_TMP_DIR
7970
8071
# The binary archives (redist) are guaranteed to be updated as part of the release posting.
72+
# Use the runtime workspace mount so this also works inside container jobs.
73+
CTK_REDIST_TOOL="${GITHUB_WORKSPACE}/ci/tools/fetch_ctk_redistrib.py"
8174
CTK_BASE_URL="https://developer.download.nvidia.com/compute/cuda/redist/"
8275
CTK_JSON_URL="$CTK_BASE_URL/redistrib_${{ inputs.cuda-version }}.json"
76+
CTK_JSON_FILE="$CACHE_TMP_DIR/redistrib.json"
77+
curl -LSs "$CTK_JSON_URL" -o "$CTK_JSON_FILE"
8378
if [[ "${{ inputs.host-platform }}" == linux* ]]; then
84-
if [[ "${{ inputs.host-platform }}" == "linux-64" ]]; then
85-
CTK_SUBDIR="linux-x86_64"
86-
elif [[ "${{ inputs.host-platform }}" == "linux-aarch64" ]]; then
87-
CTK_SUBDIR="linux-sbsa"
88-
fi
8979
function extract() {
9080
tar -xvf $1 -C $CACHE_TMP_DIR --strip-components=1
9181
}
9282
elif [[ "${{ inputs.host-platform }}" == "win-64" ]]; then
93-
CTK_SUBDIR="windows-x86_64"
9483
function extract() {
9584
_TEMP_DIR_=$(mktemp -d)
9685
unzip $1 -d $_TEMP_DIR_
@@ -106,8 +95,10 @@ runs:
10695
curl -LSs $1 -o $2
10796
}
10897
CTK_COMPONENT=$1
109-
CTK_COMPONENT_REL_PATH="$(curl -s $CTK_JSON_URL |
110-
python -c "import sys, json; print(json.load(sys.stdin)['${CTK_COMPONENT}']['${CTK_SUBDIR}']['relative_path'])")"
98+
CTK_COMPONENT_REL_PATH="$(python "$CTK_REDIST_TOOL" component-relative-path \
99+
--host-platform "${{ inputs.host-platform }}" \
100+
--component "$CTK_COMPONENT" \
101+
--metadata-path "$CTK_JSON_FILE")"
111102
CTK_COMPONENT_URL="${CTK_BASE_URL}/${CTK_COMPONENT_REL_PATH}"
112103
CTK_COMPONENT_COMPONENT_FILENAME="$(basename $CTK_COMPONENT_REL_PATH)"
113104
download $CTK_COMPONENT_URL $CTK_COMPONENT_COMPONENT_FILENAME

ci/tools/fetch_ctk_redistrib.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
#!/usr/bin/env python3
2+
#
3+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
#
5+
# SPDX-License-Identifier: Apache-2.0
6+
7+
"""Resolve mini-CTK components from NVIDIA redistrib metadata."""
8+
9+
from __future__ import annotations
10+
11+
import argparse
12+
import json
13+
import sys
14+
import urllib.error
15+
import urllib.parse
16+
import urllib.request
17+
from pathlib import Path
18+
from typing import Any
19+
20+
# Maps the CI "host-platform" identifiers accepted by this tool to the
# platform sub-directory keys used inside the redistrib metadata JSON.
HOST_PLATFORM_TO_SUBDIR: dict[str, str] = {
    "linux-64": "linux-x86_64",
    "linux-aarch64": "linux-sbsa",
    "win-64": "windows-x86_64",
}


def host_platform_to_subdir(host_platform: str) -> str:
    """Translate a host-platform identifier into its redistrib subdir name.

    Args:
        host_platform: One of the keys of ``HOST_PLATFORM_TO_SUBDIR``.

    Returns:
        The matching redistrib subdir string.

    Raises:
        ValueError: If ``host_platform`` is not a known key. The original
            ``KeyError`` is chained as the cause.
    """
    try:
        subdir = HOST_PLATFORM_TO_SUBDIR[host_platform]
    except KeyError as lookup_error:
        raise ValueError(f"unsupported host-platform: {host_platform!r}") from lookup_error
    else:
        return subdir
32+
33+
34+
def split_components(components: str) -> list[str]:
    """Split a comma-separated component list into individual names.

    Surrounding whitespace is stripped from every entry, and entries that are
    empty (or whitespace-only) are dropped. Without stripping, an input such
    as ``"cuda_nvcc, cuda_cudart"`` would yield ``" cuda_cudart"``, which can
    never match a metadata key.

    Args:
        components: Comma-separated component names; may contain stray commas
            or spaces.

    Returns:
        The non-empty, stripped names in their original order.
    """
    stripped = (part.strip() for part in components.split(","))
    return [name for name in stripped if name]
36+
37+
38+
def filter_static_components(components: list[str], host_platform: str, cuda_version: str) -> list[str]:
    """Apply the hard-coded version/platform exclusions to a component list.

    The rules are:

    * ``libnvjitlink`` is dropped when the CUDA major version is below 12.
    * ``cuda_crt`` and ``libnvvm`` are dropped when the CUDA major version is
      below 13.
    * ``libcufile`` is dropped when ``host_platform`` starts with ``"win-"``.

    Args:
        components: Component names to filter.
        host_platform: Host-platform identifier (only its ``"win-"`` prefix is
            inspected here).
        cuda_version: Dotted version string; the part before the first ``"."``
            must parse as an integer.

    Returns:
        The surviving components, in their original order.

    Raises:
        ValueError: If the major part of ``cuda_version`` is not an integer.
    """
    try:
        major = int(cuda_version.split(".", 1)[0])
    except ValueError as exc:
        raise ValueError(f"invalid cuda-version: {cuda_version!r}") from exc

    def is_kept(name: str) -> bool:
        # Each rule mirrors one exclusion; anything not matched is kept.
        if name == "libnvjitlink" and major < 12:
            return False
        if name in {"cuda_crt", "libnvvm"} and major < 13:
            return False
        return not (name == "libcufile" and host_platform.startswith("win-"))

    return [name for name in components if is_kept(name)]
54+
55+
56+
def validate_metadata_url(metadata_url: str) -> str:
    """Return ``metadata_url`` unchanged if it is an https URL with a host.

    Args:
        metadata_url: URL to check.

    Returns:
        The same string that was passed in.

    Raises:
        ValueError: If the scheme is not ``https`` or the network location is
            empty.
    """
    parts = urllib.parse.urlsplit(metadata_url)
    if parts.scheme == "https" and parts.netloc:
        return metadata_url
    raise ValueError(f"metadata URL must be an https URL: {metadata_url!r}")
61+
62+
63+
def load_metadata(
    *,
    metadata_path: str | None,
    metadata_url: str | None,
    timeout: float = 60.0,
) -> dict[str, Any]:
    """Load the redistrib metadata JSON from a local file or an https URL.

    Exactly one of ``metadata_path`` and ``metadata_url`` must be given.

    Args:
        metadata_path: Path of a UTF-8 encoded JSON file, or ``None``.
        metadata_url: https URL of the JSON document, or ``None``.
        timeout: Seconds to wait on the network before giving up. Only used
            for ``metadata_url``; it keeps a stalled connection from blocking
            the caller indefinitely.

    Returns:
        The decoded top-level JSON object.

    Raises:
        ValueError: If both or neither source is given, if the URL is not an
            https URL, if the document is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), or if its top
            level is not a JSON object.
        OSError: If the file cannot be read or the download fails or times
            out.
    """
    if (metadata_path is None) == (metadata_url is None):
        raise ValueError("exactly one of --metadata-path or --metadata-url is required")

    if metadata_path is not None:
        source = metadata_path
        metadata = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
    else:
        # Guaranteed by the check above; the assert only narrows the type.
        assert metadata_url is not None
        source = validate_metadata_url(metadata_url)
        with urllib.request.urlopen(source, timeout=timeout) as response:  # noqa: S310 - scheme is restricted to https above
            metadata = json.load(response)

    # Callers index the result with ``.get``; reject anything that is not a
    # JSON object here so the failure is a reportable ValueError.
    if not isinstance(metadata, dict):
        raise ValueError(
            f"redistrib metadata from {source!r} must be a JSON object, got {type(metadata).__name__}"
        )
    return metadata
74+
75+
76+
def filter_components(
    metadata: dict[str, Any],
    *,
    host_platform: str,
    cuda_version: str,
    components: str,
) -> tuple[list[str], list[str]]:
    """Partition requested components by availability in the metadata.

    The comma-separated ``components`` string is split, passed through the
    static version/platform rules, and each survivor is then looked up in
    ``metadata`` for the redistrib subdir that corresponds to
    ``host_platform``.

    Args:
        metadata: Decoded redistrib metadata.
        host_platform: Host-platform identifier.
        cuda_version: Dotted CUDA version string.
        components: Comma-separated component names.

    Returns:
        ``(filtered, skipped)``: components whose metadata entry contains the
        subdir, and components that passed the static rules but do not.
        Components removed by the static rules appear in neither list.

    Raises:
        ValueError: For an unsupported host platform or invalid CUDA version.
    """
    ctk_subdir = host_platform_to_subdir(host_platform)
    candidates = filter_static_components(split_components(components), host_platform, cuda_version)

    available: list[str] = []
    unavailable: list[str] = []
    for name in candidates:
        bucket = available if ctk_subdir in metadata.get(name, {}) else unavailable
        bucket.append(name)
    return available, unavailable
92+
93+
94+
def get_component_relative_path(metadata: dict[str, Any], *, host_platform: str, component: str) -> str:
    """Look up the archive ``relative_path`` of a component for a platform.

    Args:
        metadata: Decoded redistrib metadata.
        host_platform: Host-platform identifier.
        component: Component name to resolve.

    Returns:
        The ``relative_path`` string recorded for the component under the
        redistrib subdir that corresponds to ``host_platform``.

    Raises:
        ValueError: If ``host_platform`` is unsupported.
        KeyError: If the component is unknown, is not available for the
            subdir, or has no usable ``relative_path``.
    """
    ctk_subdir = host_platform_to_subdir(host_platform)

    # A non-dict value (e.g. a scalar top-level field of the manifest such as
    # a release date) is not a component; treat it like a missing one instead
    # of letting ``.get`` fail with an AttributeError.
    component_info = metadata.get(component)
    if not isinstance(component_info, dict):
        raise KeyError(f"unknown CTK component {component!r}")

    subdir_info = component_info.get(ctk_subdir)
    if not isinstance(subdir_info, dict):
        raise KeyError(f"CTK component {component!r} is not available for redistrib subdir {ctk_subdir!r}")

    relative_path = subdir_info.get("relative_path")
    if not isinstance(relative_path, str) or not relative_path:
        raise KeyError(f"CTK component {component!r} for redistrib subdir {ctk_subdir!r} is missing 'relative_path'")
    return relative_path
108+
109+
110+
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line of the redistrib helper.

    Two sub-commands are defined, selected via ``args.command``:

    * ``filter-components`` requires ``--host-platform``, ``--cuda-version``
      and ``--components``.
    * ``component-relative-path`` requires ``--host-platform`` and
      ``--component``.

    Both additionally accept the optional ``--metadata-path`` and
    ``--metadata-url``.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv``.

    Returns:
        The populated namespace.
    """
    # Sub-command name -> its required flags, in declaration order.
    required_flags = {
        "filter-components": ("--host-platform", "--cuda-version", "--components"),
        "component-relative-path": ("--host-platform", "--component"),
    }
    metadata_flags = ("--metadata-path", "--metadata-url")

    parser = argparse.ArgumentParser(description=__doc__)
    subparsers = parser.add_subparsers(dest="command", required=True)
    for command, flags in required_flags.items():
        command_parser = subparsers.add_parser(command)
        for flag in flags:
            command_parser.add_argument(flag, required=True)
        for flag in metadata_flags:
            command_parser.add_argument(flag)

    return parser.parse_args(argv)
128+
129+
130+
def main(argv: list[str] | None = None) -> int:
    """Run the selected sub-command and return a process exit status.

    ``filter-components`` prints the supported components to stdout as a
    comma-separated list and reports each skipped component on stderr.
    ``component-relative-path`` prints the component's relative archive path
    to stdout.

    Args:
        argv: Argument list; ``None`` makes argparse read ``sys.argv``.

    Returns:
        ``0`` on success, ``1`` after reporting a handled error on stderr.
    """
    args = parse_args(argv)

    try:
        metadata = load_metadata(metadata_path=args.metadata_path, metadata_url=args.metadata_url)

        if args.command == "filter-components":
            filtered, skipped = filter_components(
                metadata,
                host_platform=args.host_platform,
                cuda_version=args.cuda_version,
                components=args.components,
            )
            for component in skipped:
                print(
                    f"Skipping unsupported CTK component {component!r} for host-platform {args.host_platform!r}",
                    file=sys.stderr,
                )
            print(",".join(filtered))
            return 0

        if args.command == "component-relative-path":
            print(
                get_component_relative_path(
                    metadata,
                    host_platform=args.host_platform,
                    component=args.component,
                )
            )
            return 0

        raise AssertionError(f"unexpected command: {args.command!r}")
    except (ValueError, KeyError, OSError, urllib.error.URLError, json.JSONDecodeError) as exc:
        # str(KeyError(msg)) is repr(msg), which would wrap the message in an
        # extra layer of quotes; print the message argument itself instead.
        message = exc.args[0] if isinstance(exc, KeyError) and exc.args else exc
        print(f"ERROR: {message}", file=sys.stderr)
        return 1
165+
166+
167+
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())

0 commit comments

Comments
 (0)