Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
97 commits
Select commit Hold shift + click to select a range
4efdd1a
Add channelwise conv
pfultz2 Feb 16, 2026
a0c6b07
Format
pfultz2 Feb 16, 2026
efeafca
Use shared memory
pfultz2 Feb 16, 2026
4498934
Format
pfultz2 Feb 16, 2026
1792edb
Update slice functions
pfultz2 Feb 16, 2026
0304972
Format
pfultz2 Feb 16, 2026
1389ae5
Update to use slices instead
pfultz2 Feb 16, 2026
9c9b9a5
Format
pfultz2 Feb 16, 2026
207e5d6
Add reduce_schedule for outer batches
pfultz2 Feb 16, 2026
cdae8f4
Format
pfultz2 Feb 16, 2026
b51b82f
Use pooling_reduce
pfultz2 Feb 16, 2026
b5f4f0f
Format
pfultz2 Feb 16, 2026
15fd39f
Some refactoring to use tiling
pfultz2 Feb 16, 2026
b61daa3
FOrmat
pfultz2 Feb 16, 2026
c9d258f
Access directly
pfultz2 Feb 16, 2026
6d979f5
Format
pfultz2 Feb 16, 2026
ecbce52
Add join
pfultz2 Feb 16, 2026
4bd6556
Update tuning
pfultz2 Feb 16, 2026
d1da333
Format
pfultz2 Feb 16, 2026
9cc6906
Add multi-output
pfultz2 Feb 17, 2026
0942c87
Format
pfultz2 Feb 17, 2026
ca147d2
Add spatial tiler
pfultz2 Feb 17, 2026
3b17a09
Format
pfultz2 Feb 17, 2026
037d10f
Avoid bounds check when there is no padding
pfultz2 Feb 17, 2026
7bc6d78
Remove lines
pfultz2 Feb 17, 2026
e3077b8
Use functions instead of variables
pfultz2 Feb 17, 2026
414aab4
Format
pfultz2 Feb 17, 2026
e56c4f1
Inine methods
pfultz2 Feb 17, 2026
b51c74f
Format
pfultz2 Feb 17, 2026
3d4bfe4
Update quick tuning list
pfultz2 Feb 17, 2026
a362a19
Format
pfultz2 Feb 17, 2026
208c7ad
Add another config
pfultz2 Feb 18, 2026
f2daa29
Add more configs
pfultz2 Feb 18, 2026
36110cf
Format
pfultz2 Feb 18, 2026
882fe3b
Add pointwise fusion
pfultz2 Mar 2, 2026
24a2645
Format
pfultz2 Mar 2, 2026
28e32af
Only enable for float and navi
pfultz2 Mar 2, 2026
e35373c
Format
pfultz2 Mar 2, 2026
f69d9bb
Fix tidy
pfultz2 Mar 2, 2026
fb48be7
Format
pfultz2 Mar 2, 2026
ef923a8
Fix tidy
pfultz2 Mar 2, 2026
513fafc
Update year
pfultz2 Mar 2, 2026
ec3c657
Fix cppcheck
pfultz2 Mar 2, 2026
5d8051b
Format
pfultz2 Mar 2, 2026
99c896c
Use std algos
pfultz2 Mar 2, 2026
9f0903d
Format
pfultz2 Mar 2, 2026
680328b
Move in_bounds function
pfultz2 Mar 2, 2026
1120309
Rename type
pfultz2 Mar 2, 2026
7645792
Format
pfultz2 Mar 2, 2026
32b5894
Fix compilation failure
pfultz2 Mar 2, 2026
2141264
Format
pfultz2 Mar 2, 2026
19cf173
Simplify some more
pfultz2 Mar 2, 2026
b39416e
Format
pfultz2 Mar 2, 2026
6c990fd
Use std::transform
pfultz2 Mar 2, 2026
90638f8
Precompute slices
pfultz2 Mar 2, 2026
053bf4f
Format
pfultz2 Mar 2, 2026
ffaa5c3
Update src/targets/gpu/kernels/include/migraphx/kernels/slice.hpp
pfultz2 Mar 2, 2026
8a06baf
Change the navi check
pfultz2 Mar 2, 2026
a3fd388
Merge branch 'channelwise-conv2' of github.com:ROCmSoftwarePlatform/A…
pfultz2 Mar 2, 2026
258af41
Split verify classes
pfultz2 Mar 2, 2026
bcd468d
Revert the reduce and index changes
pfultz2 Mar 2, 2026
7ba2cca
Revert pooling changes
pfultz2 Mar 2, 2026
61f6ffb
Use signed integer
pfultz2 Mar 2, 2026
2a770dd
Merge branch 'develop' into channelwise-conv2
pfultz2 Mar 2, 2026
b5cad75
Update year
pfultz2 Mar 2, 2026
5b49459
Format
pfultz2 Mar 2, 2026
dc7f7e5
Fix merge conflicts
pfultz2 Mar 3, 2026
9eb50da
Merge branch 'develop' into channelwise-conv2
TedThemistokleous Mar 9, 2026
18a7efa
Support padding
pfultz2 Apr 3, 2026
c23a8e8
Format
pfultz2 Apr 3, 2026
747292c
Fix selection
pfultz2 Apr 3, 2026
ad9b8d1
Fix padding
pfultz2 Apr 4, 2026
77dac35
Cleanup
pfultz2 Apr 4, 2026
3c3e0ac
Merge
pfultz2 Apr 4, 2026
362ce5f
Merge branch 'develop' into channelwise-conv2
pfultz2 Apr 7, 2026
c47b394
Use generate_array instead
pfultz2 Apr 7, 2026
604d408
Use generate array
pfultz2 Apr 8, 2026
5fc446a
Format
pfultz2 Apr 8, 2026
21442c4
Add padding tests
pfultz2 Apr 8, 2026
371f79b
Format
pfultz2 Apr 8, 2026
df1676f
Merge branch 'channelwise-conv2' of github.com:ROCmSoftwarePlatform/A…
pfultz2 Apr 8, 2026
8949117
Update is_padded() check
pfultz2 Apr 8, 2026
be32bda
Format
pfultz2 Apr 8, 2026
4f5221e
Add unit tests
pfultz2 Apr 8, 2026
b0e4634
Format
pfultz2 Apr 8, 2026
de0d67a
Merge branch 'develop' into channelwise-conv2
pfultz2 Apr 18, 2026
eecf785
Fix tidy
pfultz2 Apr 22, 2026
7949e82
Merge branch 'develop' into channelwise-conv2
pfultz2 Apr 22, 2026
a3b61a2
Fix cppcheck warnings
pfultz2 Apr 22, 2026
1457b47
Format
pfultz2 Apr 22, 2026
7a5abf9
Update year
pfultz2 Apr 22, 2026
b68efb5
Merge branch 'develop' into channelwise-conv2
causten Apr 26, 2026
003f033
Merge branch 'develop' into channelwise-conv2
pfultz2 May 11, 2026
1f22234
Merge branch 'develop' into channelwise-conv2
pfultz2 May 14, 2026
7d9e876
Fix tile miscompilation
pfultz2 May 15, 2026
c19c8b5
update license
kahmed10 May 21, 2026
75f1529
Merge branch 'develop' of https://github.com/ROCm/AMDMIGraphX into ch…
kahmed10 May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,12 @@ constexpr auto make_const_array(T x, Ts... xs)
return integral_const_array<typename T::value_type, x, xs...>{};
}

template <class T, class N, class F>
constexpr auto generate_const_array(N n, F f)
{
return sequence_c<n>([=](auto... is) { return make_const_array(f(is)...); });
}

template <class T, class N, class F>
constexpr auto generate_array(N n, F f)
{
Expand Down
11 changes: 10 additions & 1 deletion src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -206,6 +206,14 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
#define MIGRAPHX_CHECK(cond) \
MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)

#ifdef CPPCHECK
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
#define MIGRAPHX_ASSUME assert(cond)
#define MIGRAPHX_UNREACHABLE assert(false)
#define MIGRAPHX_ASSERT(cond) assert(cond)
#define MIGRAPHX_WARN(cond, ...) assert(cond)
#else
#ifdef MIGRAPHX_DEBUG
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
Expand All @@ -221,6 +229,7 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_WARN(...)
#endif
#endif

#define MIGRAPHX_STATIC_ASSERT_FOR(...) \
static_assert(__VA_ARGS__); \
Expand Down
22 changes: 9 additions & 13 deletions src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* The MIT License (MIT)
*
* Copyright (C) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -59,7 +59,7 @@ class numeric_limits;
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ = true>
struct float8
{
uint8_t data;
uint8_t data = 0;
// default constructor
__device__ constexpr float8() = default;
// default copy constructor
Expand Down Expand Up @@ -140,7 +140,7 @@ struct float8
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0)
{
if(__builtin_is_constant_evaluated() or !FNUZ)
if(__builtin_is_constant_evaluated() or not FNUZ)
{
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
Expand Down Expand Up @@ -249,7 +249,7 @@ struct float8
// upcast using device specific intrinsic
constexpr __device__ operator float() const
{
if(__builtin_is_constant_evaluated() or !FNUZ)
if(__builtin_is_constant_evaluated() or not FNUZ)
{
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
Expand All @@ -261,7 +261,7 @@ struct float8
else
{
float fval = 0;
uint32_t i32val = static_cast<uint32_t>(data);
uint32_t i32val = data;

// upcast
if constexpr(T == migraphx::fp8::f8_type::fp8)
Expand Down Expand Up @@ -312,7 +312,7 @@ struct float8
}
else
{
if(T == migraphx::fp8::f8_type::bf8)
if constexpr(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or
(data == 0xFE) or (data == 0xFF);
Expand All @@ -333,7 +333,7 @@ struct float8
}
else
{
if(T == migraphx::fp8::f8_type::bf8)
if constexpr(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7C) or (data == 0xFC);
}
Expand Down Expand Up @@ -370,16 +370,12 @@ struct float8

__device__ constexpr bool operator<(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs);
return we < them;
return static_cast<float>(*this) < static_cast<float>(rhs);
}

__device__ constexpr bool operator>(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs);
return we > them;
return static_cast<float>(*this) > static_cast<float>(rhs);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ __device__ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng
if(x == 0)
return 0;
// handle negative zero
// cppcheck-suppress compareValueOutOfTypeRangeError
else if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000))
{
return NegativeZeroNan ? 0 : 0x80; // For FNUZ types neg zero is just positive zero
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -131,6 +131,9 @@ struct is_integral_constant<integral_constant<T, V>> : true_type
template <index_int N>
using index_constant = integral_constant<index_int, N>;

template <index_int N>
static constexpr auto index_c = index_constant<N>{};

template <auto V>
static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,12 @@ template <index_int N>
constexpr auto slice_group()
{
return slice_size_transform{[](auto input, auto s) {
auto r = return_array_c([] {
return return_array_c([] {
auto lens = decltype(s){}.lens.base();
lens.back() *= N;
lens -= 1;
return decltype(input){}.lens.carry(lens) + index_int{1};
});
return r;
}};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ constexpr bool has_nonzero(index_ints<Ps...>)
return ((Ps != 0) or ...);
}

template <index_int NTiles, class TileLens, class OutputShape, class Padding = index_ints<>>
template <index_int NTiles, class TileLens, class OutputShape, class Padding = index_ints<0>>
struct spatial_tiler
{
static constexpr auto keep_spatial()
Expand Down Expand Up @@ -72,39 +72,50 @@ struct spatial_tiler
static constexpr index_int tiles_total() { return tiles_per_dim().product(); }
static constexpr auto ndim() { return out_spatial_lens().size(); }

static constexpr bool is_padded()
static constexpr auto get_padding()
{
return (out_spatial_lens() != tiles_per_dim() * output_lens());
if constexpr(Padding{}.size() < 2)
{
auto pre = transform(TileLens{}, [](auto) { return index_c<0>; });
return join(pre, pre);
}
else
{
return Padding{};
}
}

static constexpr bool has_conv_padding() { return has_nonzero(Padding{}); }

// Left (begin) padding per dim: (0, 0, left_h, left_w)
static constexpr auto left_padding()
{
return return_array_c([] {
constexpr auto p = Padding{};
constexpr auto ns = p.size() / 2;
auto result = array<index_int, ns + 2>(index_int{0});
for(index_int i = 0; i < ns; i++)
result[i + 2] = p[i];
return result;
constexpr auto p = get_padding();
constexpr auto ns = p.size() / 2;
return generate_const_array<index_int>(_c<ns + 2>, [&](auto i) {
if constexpr(i < 2)
return index_c<0>;
else
return index_c<p[i - 2]>;
});
}

// Total (left+right) padding per dim: (0, 0, left_h+right_h, left_w+right_w)
static constexpr auto total_padding()
{
return return_array_c([] {
constexpr auto p = Padding{};
constexpr auto ns = p.size() / 2;
auto result = array<index_int, ns + 2>(index_int{0});
for(index_int i = 0; i < ns; i++)
result[i + 2] = p[i] + p[i + ns];
return result;
constexpr auto p = get_padding();
constexpr auto ns = p.size() / 2;
return generate_const_array<index_int>(_c<ns + 2>, [&](auto i) {
if constexpr(i < 2)
return index_c<0>;
else
return index_c<p[i - 2] + p[i - 2 + ns]>;
});
}

static constexpr bool is_padded()
{
return (out_spatial_lens() != (tiles_per_dim() * output_lens() + total_padding()));
}

Comment on lines +114 to +118
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_padded() now returns true whenever convolution padding is present (because total_padding() is included in the equality check). This forces for_each() to do per-element output bounds checks even when the output tile exactly covers the output (i.e., only input halo needs padding checks). Consider splitting this into two constexpr predicates (e.g., is_output_padded() for tile overhang vs needs_input_padding() for halo/padding), so output bounds checks remain compiled out for the common “exact tiling + conv padding” case.

Suggested change
static constexpr bool is_padded()
{
return (out_spatial_lens() != (tiles_per_dim() * output_lens() + total_padding()));
}
static constexpr bool is_output_padded()
{
return (out_spatial_lens() != (tiles_per_dim() * output_lens()));
}
static constexpr bool needs_input_padding()
{
return has_nonzero(total_padding());
}
static constexpr bool is_padded() { return is_output_padded(); }

Copilot uses AI. Check for mistakes.
index idx;
array<index_int, ndim()> tile_origin;

Expand All @@ -114,19 +125,10 @@ struct spatial_tiler
static constexpr auto halo_lens_for()
{
constexpr auto halo_extra = [] {
if constexpr(has_conv_padding())
{
return return_array_c([] {
return make_slice(InputShape{}, keep_spatial()).lens - out_spatial_lens() +
total_padding();
});
}
else
{
constexpr auto input_spatial = make_slice(InputShape{}, keep_spatial()).lens;
return transform(
input_spatial, out_spatial_lens(), [](auto is, auto os) { return is - os; });
}
return return_array_c([] {
return make_slice(InputShape{}, keep_spatial()).lens - out_spatial_lens() +
total_padding();
});
}();
return transform(output_lens(), halo_extra, [](auto o, auto h) { return o + h; });
}
Expand Down Expand Up @@ -167,19 +169,14 @@ struct spatial_tiler
idx.local_stride(_c<hl.product()>, [&](auto i) {
auto halo_multi = halo_shape.multi(i);
auto src_pos = tile_origin + halo_multi;
if constexpr(has_conv_padding())
auto input_pos = src_pos - left_padding();
if constexpr(is_padded())
{
constexpr auto pad = left_padding();
auto input_pos = src_pos - pad;
smem[i] = in_bounds(input_pos, input_spatial) ? type{input_ch[input_pos]} : type{0};
}
else if constexpr(is_padded())
{
smem[i] = in_bounds(src_pos, input_spatial) ? type{input_ch[src_pos]} : type{0};
}
else
{
smem[i] = input_ch[src_pos];
smem[i] = input_ch[input_pos];
}
});

Expand All @@ -203,7 +200,7 @@ struct spatial_tiler
}
};

template <index_int NTiles, class TileLens, class OutputShape, class Padding = index_ints<>>
template <index_int NTiles, class TileLens, class OutputShape, class Padding = index_ints<0>>
__device__ auto make_spatial_tiler(index idx, TileLens, OutputShape, Padding = {})
{
using tiler_type = spatial_tiler<NTiles, TileLens, OutputShape, Padding>;
Expand Down
11 changes: 6 additions & 5 deletions src/targets/gpu/kernels/include/migraphx/kernels/tile.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/uninitialized_buffer.hpp>
#include <migraphx/kernels/copy.hpp>

namespace migraphx {
Expand Down Expand Up @@ -61,8 +62,8 @@ struct tile
using type = typename T::type;
constexpr auto s = pad_shape(make_packed_shape(get_shape_c<T>{}));
constexpr auto size = s.element_space();
__shared__ type buffer[size];
auto b = make_tensor_view(buffer, s);
__shared__ uninitialized_buffer<type, size> buffer;
auto b = make_tensor_view(buffer.data(), s);
local_tensor_copy(idx, x, b);
f(b);
};
Expand All @@ -77,8 +78,8 @@ struct tile
using type = typename T::type;
constexpr auto s = pad_shape(make_packed_shape(get_shape_c<T>{}));
constexpr auto size = s.element_space();
__shared__ type buffer[size];
auto b = make_tensor_view(buffer, s);
__shared__ uninitialized_buffer<type, size> buffer;
auto b = make_tensor_view(buffer.data(), s);
f(b);
local_tensor_copy(idx, b, x);
};
Expand Down
Loading
Loading