Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions transformer_engine/common/util/standalone_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,6 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const
} else {
static_assert(sizeof(WideT) % sizeof(T) == 0);
constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
// TODO: it's UB
union {
WideT scalar;
T array[items_per_scalar]; // NOLINT(runtime/arrays)
} wide;

int skip_cnt =
(reinterpret_cast<size_t>(in) % sizeof(WideT))
Expand All @@ -198,11 +193,13 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const
const idxT len_cast = (len - skip_cnt) / items_per_scalar;

for (idxT i = thread_rank; i < len_cast; i += num_threads) {
wide.scalar = in_cast[i];
const WideT wide_data = in_cast[i];
T local_array[items_per_scalar]; // NOLINT(runtime/arrays)
__builtin_memcpy(local_array, &wide_data, sizeof(WideT));
const idxT real_i = skip_cnt + i * items_per_scalar;
#pragma unroll
for (int j = 0; j < items_per_scalar; ++j) {
f(wide.array[j], real_i + j);
f(local_array[j], real_i + j);
}
}

Expand Down Expand Up @@ -236,10 +233,6 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width
} else {
static_assert(sizeof(WideT) % sizeof(T) == 0);
constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
union {
WideT scalar;
T array[items_per_scalar]; // NOLINT(runtime/arrays)
} wide;

int skip_cnt =
(reinterpret_cast<size_t>(in) % sizeof(WideT))
Expand All @@ -251,16 +244,24 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width
const WideT *in_cast = reinterpret_cast<decltype(in_cast)>(in + skip_cnt);
const idxT len_cast = (len - skip_cnt) / items_per_scalar;

const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width;
for (idxT i = tid; i < len_cast_for_sync; i += stride) {
bool valid = i < len_cast;
if (valid) {
wide.scalar = in_cast[i];
}
const idxT real_i = skip_cnt + i * items_per_scalar;
// Skip when no full vector chunk exists: avoids len_cast_for_sync underflow and
// OOB companion reads (in_cast[0] needs at least one valid WideT).
if (len_cast > 0) {
const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width;
for (idxT i = tid; i < len_cast_for_sync; i += stride) {
const bool valid = i < len_cast;
// Unconditional 128-bit vector load: invalid threads read in_cast[0] (cached,
// discarded via valid=false) so NVCC emits LDG.E.128 instead of predicated load.
// Index clamping (not pointer ternary) avoids C++ UB from &in_cast[i] when i >= len_cast.
const idxT safe_i = valid ? i : static_cast<idxT>(0);
const WideT wide_data = in_cast[safe_i];
T local_array[items_per_scalar]; // NOLINT(runtime/arrays)
__builtin_memcpy(local_array, &wide_data, sizeof(WideT));
const idxT real_i = skip_cnt + i * items_per_scalar;
#pragma unroll
for (int j = 0; j < items_per_scalar; ++j) {
f(wide.array[j], real_i + j, valid);
for (int j = 0; j < items_per_scalar; ++j) {
f(local_array[j], real_i + j, valid);
}
}
}

Expand Down