Skip to content

Reuse minClusterAndDistance Helper for Balanced KMeans#2001

Open
tarang-jain wants to merge 20 commits intorapidsai:mainfrom
tarang-jain:hierarchical-helpers
Open

Reuse minClusterAndDistance Helper for Balanced KMeans#2001
tarang-jain wants to merge 20 commits intorapidsai:mainfrom
tarang-jain:hierarchical-helpers

Conversation

@tarang-jain
Copy link
Copy Markdown
Contributor

@tarang-jain tarang-jain commented Apr 8, 2026

The norm computation + fused reduction is already present in the minClusterDistanceCompute function. We can reuse that for balanced kmeans.

Furthermore, this PR updates the minClusterDistanceCompute function to also use the fused kernel for the cosine metric.

Binary size savings (conda-cpp-build check with CUDA 12.9.1 + amd):
main: 660.72 MB
This PR: 648.83 MB
(conda-cpp-build check with CUDA 13.1.1 + amd):
main: 305.70 B
This PR: 300.63 MB

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot bot commented Apr 8, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@aamijar aamijar moved this to In Progress in Unstructured Data Processing Apr 8, 2026
@aamijar aamijar added non-breaking Introduces a non-breaking change improvement Improves an existing functionality labels Apr 8, 2026
@tarang-jain tarang-jain marked this pull request as ready for review April 9, 2026 00:37
@tarang-jain tarang-jain requested a review from a team as a code owner April 9, 2026 00:37
Copy link
Copy Markdown
Member

@aamijar aamijar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the refactor Tarang! Could you add a description to the PR?

Copy link
Copy Markdown
Contributor

@jinsolp jinsolp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tarang-jain ! Suggesting small changes and adding a question:

@tarang-jain tarang-jain requested a review from jinsolp April 10, 2026 21:14
Comment on lines 194 to 197
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle,
raft::make_device_matrix_view<const DataT, IndexT>(
centroids.data_handle(), centroids.extent(0), centroids.extent(1)),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we be computing norms like this here too?

if (metric == cuvs::distance::DistanceType::CosineExpanded) {
      raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
        handle, centroids, centroidsNorm, raft::sqrt_op{});
    } else {
      raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
        handle, centroids, centroidsNorm);
    }

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I have fixed this.

n_clusters,
n_features,
(void*)workspace.data(),
metric != cuvs::distance::DistanceType::L2Expanded,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the deleted code in kmeans_balanced.cuh, this used to be false for the CosineExpended metric. However, this condition passes true for the CosineExpended metric and sqrt-s the distances output.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that is because Cosine is now supported in our fused kernel.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't being caught because we were lacking tests for the cosine metric. We always test L2 (default). I am trying to add cosine inputs to the tests next.

@achirkin
Copy link
Copy Markdown
Contributor

Could you please refresh my understanding on couple math questions?

  1. By the time we call minClusterDistanceCompute, the data is normalized, so we can use the same L2 norm everywhere, right?
  2. Why is sqrt needed for the cosine case?

@tarang-jain
Copy link
Copy Markdown
Contributor Author

tarang-jain commented Apr 14, 2026

By the time we call minClusterDistanceCompute, the data is normalized, so we can use the same L2 norm everywhere, right?

Cosine is only supported in balanced kmeans. When we call minClusterAndDistance, the data is not normalized, but the centroids are.

Why is sqrt needed for the cosine case?

sqrt is needed because the cosine distance op directly divides by the norms (it does not do the sqrt):

return static_cast<AccT>(1.0) - static_cast<AccT>(accVal / (aNorm * bNorm));

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

improvement Improves an existing functionality non-breaking Introduces a non-breaking change

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

4 participants