From e72e3441a083d9097569fa4e84134ae033e511c7 Mon Sep 17 00:00:00 2001 From: daniellegillai Date: Fri, 29 May 2026 15:08:14 -0700 Subject: [PATCH] convert-to shape validation added --- src/op_convert_to.cpp | 8 +-- .../tests/operators/test_op_convert_to.cpp | 59 +++++++++++++++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/src/op_convert_to.cpp b/src/op_convert_to.cpp index 3baa4577..f814498e 100644 --- a/src/op_convert_to.cpp +++ b/src/op_convert_to.cpp @@ -122,16 +122,10 @@ void ConvertTo::operator()(hipStream_t stream, const Tensor &input, const Tensor CHECK_TENSOR_CHANNELS(input, 1, 2, 3, 4); eDataType input_dtype = input.dtype().etype(); - int64_t channels = input.shape(input.layout().channels_index()); // Validate output tensor CHECK_TENSOR_COMPARISON(input.device() == output.device()); - CHECK_TENSOR_COMPARISON(output.shape(output.layout().channels_index()) == channels); - CHECK_TENSOR_COMPARISON(output.layout() == input.layout()); - if (output.layout().batch_index() != -1) { - CHECK_TENSOR_COMPARISON(output.shape(output.layout().batch_index()) == - input.shape(input.layout().batch_index())); - } + CHECK_TENSOR_COMPARISON(input.shape() == output.shape()); // Select kernel dispatcher based on a base input datatype. // clang-format off diff --git a/tests/roccv/cpp/src/tests/operators/test_op_convert_to.cpp b/tests/roccv/cpp/src/tests/operators/test_op_convert_to.cpp index eaa714a9..6f14bdbd 100644 --- a/tests/roccv/cpp/src/tests/operators/test_op_convert_to.cpp +++ b/tests/roccv/cpp/src/tests/operators/test_op_convert_to.cpp @@ -112,6 +112,63 @@ void TestCorrectness(int batchSize, int width, int height, ImageFormat inFormat, CompareVectorsNear(result, ref, 1.0E-4); } +void TestNegativeConvertTo() { + TensorShape validShape(TensorLayout(eTensorLayout::TENSOR_LAYOUT_NHWC), {1, 1, 1, 1}); + Tensor validGPUTensor(validShape, DataType(eDataType::DATA_TYPE_U8), eDeviceType::GPU); + Tensor validCPUTensor(validShape, DataType(eDataType::DATA_TYPE_U8), eDeviceType::CPU); + ConvertTo op; + + { + // Test output tensor on CPU for GPU operation + EXPECT_EXCEPTION(op(nullptr, validGPUTensor, validCPUTensor, 1.0, 0.0, eDeviceType::GPU), + eStatusType::INVALID_COMBINATION); + } + + { + // Test input tensor on CPU for GPU operation + EXPECT_EXCEPTION(op(nullptr, validCPUTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU), + eStatusType::INVALID_OPERATION); + } + + { + // Test unsupported layout + TensorShape invalidLayoutShape(TensorLayout(eTensorLayout::TENSOR_LAYOUT_NC), {1, 1}); + Tensor invalidTensor(invalidLayoutShape, DataType(eDataType::DATA_TYPE_U8), eDeviceType::GPU); + EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU), + eStatusType::INVALID_COMBINATION); + } + + { + // Test unsupported data type + Tensor invalidTensor(validGPUTensor.shape(), DataType(eDataType::DATA_TYPE_U32), eDeviceType::GPU); + EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU), eStatusType::NOT_IMPLEMENTED); + } + + { + // Test input/output channel mismatch + Tensor invalidTensor(TensorShape(validGPUTensor.layout(), {1, 1, 1, 2}), DataType(eDataType::DATA_TYPE_U8), + eDeviceType::GPU); + EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU), + eStatusType::INVALID_COMBINATION); + } + + { + // Test input/output width/height mismatch + Tensor invalidTensor(TensorShape(validGPUTensor.layout(), {1, 2, 2, 1}), DataType(eDataType::DATA_TYPE_U8), + eDeviceType::GPU); + EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU), + eStatusType::INVALID_COMBINATION); + } + + { + // Test input/output batch mismatch + Tensor invalidTensor(TensorShape(validGPUTensor.layout(), {2, 1, 1, 1}), DataType(eDataType::DATA_TYPE_U8), + eDeviceType::GPU); + EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU), + eStatusType::INVALID_COMBINATION); + } +} + } // namespace int main(int argc, char** argv) { @@ -119,6 +176,8 @@ int main(int argc, char** argv) { (void)argv; TEST_CASES_BEGIN(); + TEST_CASE(TestNegativeConvertTo()); + // CPU correctness tests // 1 Channel TEST_CASE((TestCorrectness(1, 480, 360, FMT_U8, FMT_U8, 1.2, 10.2, eDeviceType::CPU)));