Reuse minClusterAndDistance Helper for Balanced KMeans#2001
Reuse minClusterAndDistance Helper for Balanced KMeans#2001tarang-jain wants to merge 20 commits intorapidsai:mainfrom
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
jinsolp
left a comment
There was a problem hiding this comment.
Thanks @tarang-jain ! Suggesting small changes and adding a question:
Co-authored-by: Jinsol Park <jinsolp@nvidia.com>
…/cuvs into hierarchical-helpers
| raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>( | ||
| handle, | ||
| raft::make_device_matrix_view<const DataT, IndexT>( | ||
| centroids.data_handle(), centroids.extent(0), centroids.extent(1)), |
There was a problem hiding this comment.
should we be computing norms like this here too?
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle, centroids, centroidsNorm, raft::sqrt_op{});
} else {
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle, centroids, centroidsNorm);
}
There was a problem hiding this comment.
Good catch! I have fixed this.
| n_clusters, | ||
| n_features, | ||
| (void*)workspace.data(), | ||
| metric != cuvs::distance::DistanceType::L2Expanded, |
There was a problem hiding this comment.
Looking at the deleted code in kmeans_balanced.cuh, this used to be false for the CosineExpended metric. However, this condition passes true for the CosineExpended metric and sqrt-s the distances output.
There was a problem hiding this comment.
Yes that is because Cosine is now supported in our fused kernel.
There was a problem hiding this comment.
This wasn't being caught because we were lacking tests for the cosine metric. We always test L2 (default). I am trying to add cosine inputs to the tests next.
|
Could you please refresh my understanding on couple math questions?
|
Cosine is only supported in balanced kmeans. When we call minClusterAndDistance, the data is not normalized, but the centroids are.
sqrt is needed because the cosine distance op directly divides by the norms (it does not do the sqrt): |
The norm computation + fused reduction is already present in the minClusterDistanceCompute function. We can reuse that for balanced kmeans.
Furthermore, this PR updates the minClusterDistanceCompute function to also use the fused kernel for the cosine metric.
Binary size savings (conda-cpp-build check with CUDA 12.9.1 + amd):
main: 660.72 MB
This PR: 648.83 MB
(conda-cpp-build check with CUDA 13.1.1 + amd):
main: 305.70 B
This PR: 300.63 MB