Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 17 additions & 27 deletions c/src/core/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@
#include <raft/util/cudart_utils.hpp>
#include <rapids_logger/logger.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/mr/cuda_memory_resource.hpp>
#include <rmm/mr/managed_memory_resource.hpp>
#include <rmm/mr/owning_wrapper.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/mr/pool_memory_resource.hpp>
#include <rmm/mr/pinned_host_memory_resource.hpp>
#include <rmm/mr/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include "../core/exceptions.hpp"

#include <cuda/memory_resource>

#include <cstdint>
#include <memory>
#include <optional>
#include <thread>

extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res)
Expand Down Expand Up @@ -132,60 +135,47 @@ extern "C" cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t byte
{
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto mr = rmm::mr::get_current_device_resource();
*ptr = mr->allocate(raft::resource::get_cuda_stream(*res_ptr), bytes);
auto mr = rmm::mr::get_current_device_resource_ref();
*ptr = mr.allocate(raft::resource::get_cuda_stream(*res_ptr), bytes);
});
}

extern "C" cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes)
{
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto mr = rmm::mr::get_current_device_resource();
mr->deallocate(raft::resource::get_cuda_stream(*res_ptr), ptr, bytes);
auto mr = rmm::mr::get_current_device_resource_ref();
mr.deallocate(raft::resource::get_cuda_stream(*res_ptr), ptr, bytes);
});
}

thread_local std::shared_ptr<
rmm::mr::owning_wrapper<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>,
rmm::mr::device_memory_resource>>
pool_mr;
thread_local cuda::mr::any_resource<cuda::mr::device_accessible> pool_upstream;
thread_local std::optional<rmm::mr::pool_memory_resource> pool_mr;

extern "C" cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent,
int max_pool_size_percent,
bool managed)
{
return cuvs::core::translate_exceptions([=] {
// Upstream memory resource needs to be a cuda_memory_resource
auto cuda_mr = rmm::mr::get_current_device_resource();
auto* cuda_mr_casted = dynamic_cast<rmm::mr::cuda_memory_resource*>(cuda_mr);
if (cuda_mr_casted == nullptr) {
throw std::runtime_error("Current memory resource is not a cuda_memory_resource");
}

auto initial_size = rmm::percent_of_free_device_memory(initial_pool_size_percent);
auto max_size = rmm::percent_of_free_device_memory(max_pool_size_percent);

auto mr = std::shared_ptr<rmm::mr::device_memory_resource>();
if (managed) {
mr = std::static_pointer_cast<rmm::mr::device_memory_resource>(
std::make_shared<rmm::mr::managed_memory_resource>());
pool_upstream = rmm::mr::managed_memory_resource{};
} else {
mr = std::static_pointer_cast<rmm::mr::device_memory_resource>(
std::make_shared<rmm::mr::cuda_memory_resource>());
pool_upstream = rmm::mr::cuda_memory_resource{};
}

pool_mr =
rmm::mr::make_owning_wrapper<rmm::mr::pool_memory_resource>(mr, initial_size, max_size);
pool_mr.emplace(pool_upstream, initial_size, max_size);

rmm::mr::set_current_device_resource(pool_mr.get());
rmm::mr::set_current_device_resource(*pool_mr);
});
}

extern "C" cuvsError_t cuvsRMMMemoryResourceReset()
{
return cuvs::core::translate_exceptions([=] {
rmm::mr::set_current_device_resource(rmm::mr::detail::initial_resource());
rmm::mr::reset_current_device_resource();
pool_mr.reset();
});
}
Expand Down
3 changes: 2 additions & 1 deletion ci/build_cpp.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail
source ./ci/use_conda_packages_from_prs.sh

source rapids-configure-sccache
source rapids-date-string
Expand Down
2 changes: 2 additions & 0 deletions ci/build_docs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail
source ./ci/use_conda_packages_from_prs.sh

rapids-logger "Downloading artifacts from previous jobs"
CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
Expand All @@ -18,6 +19,7 @@ RAPIDS_VERSION_MAJOR_MINOR="$(rapids-version-major-minor)"
export RAPIDS_VERSION_MAJOR_MINOR

rapids-dependency-file-generator \
"${RAPIDS_EXTRA_CONDA_CHANNEL_ARGS[@]}" \
--output conda \
--file-key docs \
--matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" \
Expand Down
2 changes: 2 additions & 0 deletions ci/build_go.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail
source ./ci/use_conda_packages_from_prs.sh

rapids-logger "Downloading artifacts from previous jobs"
CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
Expand All @@ -14,6 +15,7 @@ rapids-logger "Configuring conda strict channel priority"
conda config --set channel_priority strict

rapids-dependency-file-generator \
"${RAPIDS_EXTRA_CONDA_CHANNEL_ARGS[@]}" \
--output conda \
--file-key go \
--matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" \
Expand Down
4 changes: 3 additions & 1 deletion ci/build_java.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail
source ./ci/use_conda_packages_from_prs.sh

source rapids-configure-sccache

Expand Down Expand Up @@ -33,6 +34,7 @@ rapids-logger "Generate Java testing dependencies"
ENV_YAML_DIR="$(mktemp -d)"

rapids-dependency-file-generator \
"${RAPIDS_EXTRA_CONDA_CHANNEL_ARGS[@]}" \
--output conda \
--file-key java \
--prepend-channel "${CPP_CHANNEL}" \
Expand Down
1 change: 1 addition & 0 deletions ci/build_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail
source ./ci/use_conda_packages_from_prs.sh

source rapids-configure-sccache
source rapids-date-string
Expand Down
4 changes: 3 additions & 1 deletion ci/build_rust.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail
source ./ci/use_conda_packages_from_prs.sh

rapids-logger "Downloading artifacts from previous jobs"
CPP_CHANNEL=$(rapids-download-conda-from-github cpp)
Expand All @@ -14,6 +15,7 @@ rapids-logger "Configuring conda strict channel priority"
conda config --set channel_priority strict

rapids-dependency-file-generator \
"${RAPIDS_EXTRA_CONDA_CHANNEL_ARGS[@]}" \
--output conda \
--file-key rust \
--matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" \
Expand Down
1 change: 1 addition & 0 deletions ci/build_standalone_c.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail
source ./ci/use_conda_packages_from_prs.sh

TOOLSET_VERSION=14
NINJA_VERSION=v1.13.1
Expand Down
3 changes: 2 additions & 1 deletion ci/build_wheel_cuvs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail

source rapids-init-pip

source ./ci/use_wheels_from_prs.sh

package_dir="python/cuvs"

RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen "${RAPIDS_CUDA_VERSION}")"
Expand Down
3 changes: 2 additions & 1 deletion ci/build_wheel_libcuvs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail

source rapids-init-pip

source ./ci/use_wheels_from_prs.sh

package_name="libcuvs"
package_dir="python/libcuvs"

Expand Down
2 changes: 2 additions & 0 deletions ci/test_cpp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail
source ./ci/use_conda_packages_from_prs.sh

. /opt/conda/etc/profile.d/conda.sh

Expand All @@ -13,6 +14,7 @@ CPP_CHANNEL=$(rapids-download-conda-from-github cpp)

rapids-logger "Generate C++ testing dependencies"
rapids-dependency-file-generator \
"${RAPIDS_EXTRA_CONDA_CHANNEL_ARGS[@]}" \
--output conda \
--file-key test_cpp \
--matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch)" \
Expand Down
2 changes: 2 additions & 0 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail
source ./ci/use_conda_packages_from_prs.sh

. /opt/conda/etc/profile.d/conda.sh

Expand All @@ -15,6 +16,7 @@ PYTHON_CHANNEL=$(rapids-download-from-github "$(rapids-package-name "conda_pytho

rapids-logger "Generate Python testing dependencies"
rapids-dependency-file-generator \
"${RAPIDS_EXTRA_CONDA_CHANNEL_ARGS[@]}" \
--output conda \
--file-key test_python \
--matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" \
Expand Down
3 changes: 2 additions & 1 deletion ci/test_wheel_cuvs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail

source rapids-init-pip

source ./ci/use_wheels_from_prs.sh

# Delete system libnccl.so to ensure the wheel is used
rm -rf /usr/lib64/libnccl*

Expand Down
24 changes: 24 additions & 0 deletions ci/use_conda_packages_from_prs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

# This file is pulled in with `source ./ci/use_conda_packages_from_prs.sh` by
# the ci/*.sh entry points, so every variable assigned below remains visible
# in the sourcing script.
#
# download CI artifacts
# Each call names a repository (rmm / raft), a PR number and the artifact kind
# (cpp or python conda packages); its stdout is captured and used below as a
# conda channel.
# NOTE(review): presumably the captured value is a local directory containing
# the downloaded packages -- confirm against rapids-get-pr-artifact.
LIBRMM_CHANNEL=$(rapids-get-pr-artifact rmm 2361 cpp conda)
RMM_CHANNEL=$(rapids-get-pr-artifact rmm 2361 python conda --stable)
LIBRAFT_CHANNEL=$(rapids-get-pr-artifact raft 2996 cpp conda)
RAFT_CHANNEL=$(rapids-get-pr-artifact raft 2996 python conda --stable)

# The four PR channels collected into one array.
# NOTE(review): bash does not pass array variables to child processes through
# the environment, so this `export` only has an effect inside the sourcing
# shell -- confirm nothing expects to read it from a subprocess.
RAPIDS_PREPENDED_CONDA_CHANNELS=(
"${LIBRMM_CHANNEL}"
"${RMM_CHANNEL}"
"${LIBRAFT_CHANNEL}"
"${RAFT_CHANNEL}"
)
export RAPIDS_PREPENDED_CONDA_CHANNELS

# For every channel: add it to conda's system-level channel list and append a
# `--prepend-channel <channel>` pair to RAPIDS_EXTRA_CONDA_CHANNEL_ARGS. The
# sourcing scripts expand "${RAPIDS_EXTRA_CONDA_CHANNEL_ARGS[@]}" on their
# rapids-dependency-file-generator command lines.
RAPIDS_EXTRA_CONDA_CHANNEL_ARGS=()
for _channel in "${RAPIDS_PREPENDED_CONDA_CHANNELS[@]}"
do
conda config --system --add channels "${_channel}"
RAPIDS_EXTRA_CONDA_CHANNEL_ARGS+=(--prepend-channel "${_channel}")
done
25 changes: 25 additions & 0 deletions ci/use_wheels_from_prs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

# This file is pulled in with `source ./ci/use_wheels_from_prs.sh` by the wheel
# build/test scripts, after they run `source rapids-init-pip`. It appends to
# the file named by PIP_CONSTRAINT, which is not assigned anywhere in this
# script.
# NOTE(review): presumably rapids-init-pip defines PIP_CONSTRAINT -- confirm.

# Suffix derived from RAPIDS_CUDA_VERSION; it is embedded in the wheel package
# names used below (e.g. librmm_<suffix>, librmm-<suffix>).
RAPIDS_PY_CUDA_SUFFIX=$(rapids-wheel-ctk-name-gen "${RAPIDS_CUDA_VERSION}")

# Capture the output of rapids-get-pr-artifact for the rmm (PR 2361) and raft
# (PR 2996) wheels. For the two cpp wheels, RAPIDS_PY_WHEEL_NAME is set as a
# per-command environment prefix, so it applies to that single invocation only.
# NOTE(review): each captured value is used below as a directory that contains
# the downloaded .whl file -- confirm against rapids-get-pr-artifact.
LIBRMM_WHEELHOUSE=$(
RAPIDS_PY_WHEEL_NAME="librmm_${RAPIDS_PY_CUDA_SUFFIX}" rapids-get-pr-artifact rmm 2361 cpp wheel
)
RMM_WHEELHOUSE=$(
rapids-get-pr-artifact rmm 2361 python wheel --stable
)
LIBRAFT_WHEELHOUSE=$(
RAPIDS_PY_WHEEL_NAME="libraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-get-pr-artifact raft 2996 cpp wheel
)
RAFT_WHEELHOUSE=$(
rapids-get-pr-artifact raft 2996 python wheel --stable --pkg_name pylibraft
)

# Append one "<package> @ file://<wheel path>" line per package to the pip
# constraints file. The heredoc delimiter is unquoted, so each
# $(echo <dir>/<pattern>.whl) substitution runs while the heredoc is written
# and the glob is expanded to the matching wheel path.
# NOTE(review): assumes exactly one wheel matches each glob; with no match the
# literal pattern is written, and with several matches the line holds several
# paths -- confirm this cannot happen in CI.
cat >> "${PIP_CONSTRAINT}" <<EOF
librmm-${RAPIDS_PY_CUDA_SUFFIX} @ file://$(echo "${LIBRMM_WHEELHOUSE}"/librmm_*.whl)
rmm-${RAPIDS_PY_CUDA_SUFFIX} @ file://$(echo "${RMM_WHEELHOUSE}"/rmm_*.whl)
libraft-${RAPIDS_PY_CUDA_SUFFIX} @ file://$(echo "${LIBRAFT_WHEELHOUSE}"/libraft_*.whl)
pylibraft-${RAPIDS_PY_CUDA_SUFFIX} @ file://$(echo "${RAFT_WHEELHOUSE}"/pylibraft_*.whl)
EOF
Loading
Loading