Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions cpp/src/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ void kmeansPlusPlus(raft::resources const& handle,
raft::device_matrix_view<DataT, IndexT> candidates_view(
centroidCandidates.data_handle(), n_trials, n_features);

// L2 norm of X: ||x||^2
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);

if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle, X, L2NormX.view(), raft::sqrt_op{});
} else if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
}

Expand Down Expand Up @@ -342,13 +344,15 @@ void kmeans_fit_main(raft::resources const& handle,

rmm::device_scalar<DataT> clusterCostD(stream);

// L2 norm of X: ||x||^2
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
auto l2normx_view =
raft::make_device_vector_view<const DataT, IndexT>(L2NormX.data_handle(), n_samples);

if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle, X, L2NormX.view(), raft::sqrt_op{});
} else if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
}

Expand Down Expand Up @@ -523,10 +527,12 @@ void initScalableKMeansPlusPlus(raft::resources const& handle,
// destructor releases the resource
rmm::device_uvector<DataT> L2NormBuf_OR_DistBuf(0, stream);

// L2 norm of X: ||x||^2
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle, X, L2NormX.view(), raft::sqrt_op{});
} else if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
}

Expand Down Expand Up @@ -933,10 +939,12 @@ void kmeans_predict(raft::resources const& handle,
raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
rmm::device_uvector<DataT> L2NormBuf_OR_DistBuf(0, stream);

// L2 norm of X: ||x||^2
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle, X, L2NormX.view(), raft::sqrt_op{});
} else if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
}

Expand Down
86 changes: 18 additions & 68 deletions cpp/src/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#pragma once

#include "../../distance/fused_distance_nn.cuh"
#include "kmeans_common.cuh"
#include <cuvs/cluster/kmeans.hpp>

Expand Down Expand Up @@ -88,80 +87,31 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
auto stream = raft::resource::get_cuda_stream(handle);
switch (params.metric) {
case cuvs::distance::DistanceType::L2Expanded:
case cuvs::distance::DistanceType::L2SqrtExpanded: {
auto workspace = raft::make_device_mdarray<char, IdxT>(
handle, mr, raft::make_extents<IdxT>((sizeof(int)) * n_rows));

auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, raft::make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value);

auto centroidsNorm =
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_clusters));
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle,
raft::make_device_matrix_view<const MathT, IdxT, raft::row_major>(centers, n_clusters, dim),
centroidsNorm.view());

cuvs::distance::fusedDistanceNNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
minClusterAndDistance.data_handle(),
dataset,
centers,
dataset_norm,
centroidsNorm.data_handle(),
n_rows,
n_clusters,
dim,
(void*)workspace.data_handle(),
(params.metric == cuvs::distance::DistanceType::L2Expanded) ? false : true,
false,
true,
params.metric,
0.0f,
stream);

// todo(lsugy): use KVP + iterator in caller.
// Copy keys to output labels
raft::linalg::map(handle,
raft::make_const_mdspan(minClusterAndDistance.view()),
raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows),
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
break;
}
case cuvs::distance::DistanceType::L2SqrtExpanded:
case cuvs::distance::DistanceType::CosineExpanded: {
auto workspace = raft::make_device_mdarray<char, IdxT>(
handle, mr, raft::make_extents<IdxT>((sizeof(int)) * n_rows));
rmm::device_uvector<MathT> L2NormBuf_OR_DistBuf(0, stream, mr);
rmm::device_uvector<char> workspace(0, stream, mr);

auto X_view = raft::make_device_matrix_view<const MathT, IdxT>(dataset, n_rows, dim);
auto centroids_view =
raft::make_device_matrix_view<const MathT, IdxT>(centers, n_clusters, dim);
auto X_norm_view = raft::make_device_vector_view<const MathT, IdxT>(dataset_norm, n_rows);

auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, raft::make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value);

auto centroidsNorm =
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_clusters));
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
handle,
raft::make_device_matrix_view<const MathT, IdxT, raft::row_major>(centers, n_clusters, dim),
centroidsNorm.view(),
raft::sqrt_op{});

cuvs::distance::fusedDistanceNNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
minClusterAndDistance.data_handle(),
dataset,
centers,
dataset_norm,
centroidsNorm.data_handle(),
n_rows,
n_clusters,
dim,
(void*)workspace.data_handle(),
false,
false,
true,
X_view,
centroids_view,
minClusterAndDistance.view(),
X_norm_view,
L2NormBuf_OR_DistBuf,
params.metric,
0.0f,
stream);
0, // batch_samples (unused for fused reduction)
0, // batch_centroids (unused for fused reduction)
workspace);

// Copy keys to output labels
raft::linalg::map(handle,
raft::make_const_mdspan(minClusterAndDistance.view()),
Expand Down
1 change: 0 additions & 1 deletion cpp/src/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#pragma once

#include "../../distance/distance.cuh"
#include "../../distance/fused_distance_nn.cuh"
#include <cstdint>
#include <cuvs/cluster/kmeans.hpp>
#include <cuvs/distance/distance.hpp>
Expand Down
Loading
Loading