Skip to content

Commit 6d993d1

Browse files
committed
cr fixes
Change-Id: I3a3fca2b842702672c915e1357602bf0fe47bd47
1 parent 5a45462 commit 6d993d1

3 files changed

Lines changed: 213 additions & 470 deletions

File tree

arm_compute/core/IReducibleTensor.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,16 @@
2626
#define ACL_ARM_COMPUTE_CORE_IREDUCIBLETENSOR_H
2727

2828
#include "arm_compute/core/SparseTensor.h"
29-
#include "arm_compute/runtime/COOTensor.h"
30-
#include "arm_compute/runtime/CSRTensor.h"
3129

3230
namespace arm_compute
3331
{
34-
/** */
32+
/** Forward declaration of COOTensor and CSRTensor class */
33+
class COOTensor;
34+
class CSRTensor;
35+
36+
/** Interface for all reducible tensors, i.e. all tensors that can be
37+
* converted to a sparse representation.
38+
*/
3539
class IReducibleTensor
3640
{
3741
public:
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/*
2+
* Copyright (c) 2025 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#include "arm_compute/core/TensorFormat.h"
25+
#include "arm_compute/core/Types.h"
26+
#include "arm_compute/runtime/Tensor.h"
27+
#include "arm_compute/runtime/COOTensor.h"
28+
#include "tests/framework/Asserts.h"
29+
#include "tests/framework/Macros.h"
30+
#include "tests/framework/datasets/Datasets.h"
31+
#include "tests/validation/Validation.h"
32+
#include "tests/validation/Helpers.h"
33+
34+
#include "tests/NEON/Accessor.h"
35+
#include "tests/NEON/Helper.h"
36+
37+
#include <vector>
38+
39+
namespace arm_compute
40+
{
41+
namespace
42+
{
43+
bool are_values_equal(const uint8_t *a, const uint8_t *b, DataType dt, size_t element_size)
44+
{
45+
if(dt == DataType::F32)
46+
{
47+
float va = *reinterpret_cast<const float *>(a);
48+
float vb = *reinterpret_cast<const float *>(b);
49+
if(std::fabs(va - vb) > 0e-5f)
50+
{
51+
return false;
52+
}
53+
} else
54+
{
55+
if(std::memcmp(a, b, element_size) != 0)
56+
{
57+
return false;
58+
}
59+
}
60+
61+
return true;
62+
}
63+
64+
bool tensors_are_equal(const test::Accessor &a, const test::Accessor &b)
65+
{
66+
if(a.shape() != b.shape() || a.data_type() != b.data_type())
67+
return false;
68+
69+
const size_t element_size = a.element_size();
70+
Window window;
71+
window.use_tensor_dimensions(a.shape());
72+
73+
bool equal = true;
74+
75+
execute_window_loop(window, [&](const Coordinates &id)
76+
{
77+
const uint8_t *a_value = static_cast<const uint8_t *>(a(id));
78+
const uint8_t *b_value = static_cast<const uint8_t *>(b(id));
79+
80+
equal = are_values_equal(a_value, b_value, a.data_type(), element_size);
81+
});
82+
83+
return equal;
84+
}
85+
} // namespace
86+
87+
namespace test
88+
{
89+
namespace validation
90+
{
91+
TEST_SUITE(UNIT)
92+
TEST_SUITE(SparseTensor)
93+
94+
// clang-format off
95+
/** Validates TensorInfo Autopadding */
96+
DATA_TEST_CASE(ConvertCOOTensorToDense, framework::DatasetMode::ALL, combine(
97+
framework::dataset::make("TensorShape", {
98+
TensorShape(8U),
99+
TensorShape(3U, 3U),
100+
TensorShape(2U, 5U, 5U),
101+
TensorShape(4U, 2U, 2U, 9U)}),
102+
framework::dataset::make("TensorType", {
103+
DataType::U8,
104+
DataType::S8,
105+
DataType::U32,
106+
DataType::S32,
107+
DataType::F16,
108+
DataType::F32})
109+
), shape, type)
110+
{
111+
const auto t_info = TensorInfo(shape, 1, type, DataLayout::NCHW);
112+
auto t = create_tensor<Tensor>(t_info);
113+
auto t_zero = create_tensor<Tensor>(t_info);
114+
115+
t.allocator()->allocate();
116+
library->fill_tensor_sparse_random(Accessor(t), 0.2);
117+
118+
t_zero.allocator()->allocate();
119+
library->fill_static_values(Accessor(t_zero), std::vector<unsigned>(shape.total_size(), 0));
120+
121+
for(size_t sparse_dim = 1; sparse_dim <= shape.num_dimensions(); sparse_dim++)
122+
{
123+
auto st = t.to_coo_sparse(sparse_dim);
124+
bool is_sparse = st->info()->is_sparse();
125+
bool is_coo = st->info()->tensor_format() == TensorFormat::COO;
126+
size_t dense_dim = shape.num_dimensions() - sparse_dim;
127+
size_t is_hybrid = dense_dim > 0;
128+
auto td = st->to_dense();
129+
130+
ARM_COMPUTE_EXPECT(is_sparse, framework::LogLevel::ERRORS);
131+
ARM_COMPUTE_EXPECT(is_coo, framework::LogLevel::ERRORS);
132+
ARM_COMPUTE_EXPECT(st->sparse_dim() == sparse_dim, framework::LogLevel::ERRORS);
133+
ARM_COMPUTE_EXPECT(st->dense_dim() == dense_dim, framework::LogLevel::ERRORS);
134+
ARM_COMPUTE_EXPECT(st->is_hybrid() == is_hybrid, framework::LogLevel::ERRORS);
135+
ARM_COMPUTE_EXPECT(tensors_are_equal(Accessor(t), Accessor(*td)), framework::LogLevel::ERRORS);
136+
137+
auto st_zero = t_zero.to_coo_sparse(sparse_dim);
138+
auto td_zero = st_zero->to_dense();
139+
ARM_COMPUTE_EXPECT(tensors_are_equal(Accessor(t_zero), Accessor(*td_zero)), framework::LogLevel::ERRORS);
140+
}
141+
}
142+
// clang-format on
143+
// *INDENT-ON*
144+
145+
// clang-format off
146+
/** Validates TensorInfo Autopadding */
147+
DATA_TEST_CASE(ConvertCSRTensorToDense, framework::DatasetMode::ALL, combine(
148+
framework::dataset::make("TensorShape", {
149+
TensorShape(8U),
150+
TensorShape(3U, 3U),
151+
TensorShape(2U, 5U, 5U),
152+
TensorShape(4U, 2U, 2U, 9U)}),
153+
framework::dataset::make("TensorType", {
154+
DataType::U8,
155+
DataType::S8,
156+
DataType::U32,
157+
DataType::S32,
158+
DataType::F16,
159+
DataType::F32})
160+
), shape, type)
161+
{
162+
// Currently, CSRTensor only supports 2D tensors
163+
if(shape.num_dimensions() < 2)
164+
{
165+
return;
166+
}
167+
const TensorShape tensor_shape(shape[0], shape[1]);
168+
169+
const auto t_info = TensorInfo(tensor_shape, 1, type, DataLayout::NCHW);
170+
auto t = create_tensor<Tensor>(t_info);
171+
auto t_zero = create_tensor<Tensor>(t_info);
172+
173+
t.allocator()->allocate();
174+
library->fill_tensor_sparse_random(Accessor(t), 0.2);
175+
176+
t_zero.allocator()->allocate();
177+
library->fill_static_values(Accessor(t_zero), std::vector<unsigned>(tensor_shape.total_size(), 0));
178+
179+
auto st = t.to_csr_sparse();
180+
auto td = st->to_dense();
181+
bool is_sparse = st->info()->is_sparse();
182+
bool is_csr = st->info()->tensor_format() == TensorFormat::CSR;
183+
size_t sparse_dim = tensor_shape.num_dimensions();
184+
size_t dense_dim = tensor_shape.num_dimensions() - sparse_dim;
185+
size_t is_hybrid = dense_dim > 0;
186+
187+
ARM_COMPUTE_EXPECT(is_sparse, framework::LogLevel::ERRORS);
188+
ARM_COMPUTE_EXPECT(is_csr, framework::LogLevel::ERRORS);
189+
ARM_COMPUTE_EXPECT(st->sparse_dim() == sparse_dim, framework::LogLevel::ERRORS);
190+
ARM_COMPUTE_EXPECT(st->dense_dim() == dense_dim, framework::LogLevel::ERRORS);
191+
ARM_COMPUTE_EXPECT(st->is_hybrid() == is_hybrid, framework::LogLevel::ERRORS);
192+
ARM_COMPUTE_EXPECT(tensors_are_equal(Accessor(t), Accessor(*td)), framework::LogLevel::ERRORS);
193+
194+
auto st_zero = t_zero.to_coo_sparse(sparse_dim);
195+
auto td_zero = st_zero->to_dense();
196+
ARM_COMPUTE_EXPECT(tensors_are_equal(Accessor(t_zero), Accessor(*td_zero)), framework::LogLevel::ERRORS);
197+
}
198+
// clang-format on
199+
// *INDENT-ON*
200+
201+
TEST_SUITE_END() // SparseTensor
202+
TEST_SUITE_END() // UNIT
203+
204+
} // namespace validation
205+
} // namespace test
206+
} // namespace arm_compute

0 commit comments

Comments
 (0)