Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions src/op_convert_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,10 @@ void ConvertTo::operator()(hipStream_t stream, const Tensor &input, const Tensor
CHECK_TENSOR_CHANNELS(input, 1, 2, 3, 4);

eDataType input_dtype = input.dtype().etype();
int64_t channels = input.shape(input.layout().channels_index());

// Validate output tensor
CHECK_TENSOR_COMPARISON(input.device() == output.device());
CHECK_TENSOR_COMPARISON(output.shape(output.layout().channels_index()) == channels);
CHECK_TENSOR_COMPARISON(output.layout() == input.layout());
if (output.layout().batch_index() != -1) {
CHECK_TENSOR_COMPARISON(output.shape(output.layout().batch_index()) ==
input.shape(input.layout().batch_index()));
}
CHECK_TENSOR_COMPARISON(input.shape() == output.shape());

// Select kernel dispatcher based on a base input datatype.
// clang-format off
Expand Down
59 changes: 59 additions & 0 deletions tests/roccv/cpp/src/tests/operators/test_op_convert_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,72 @@ void TestCorrectness(int batchSize, int width, int height, ImageFormat inFormat,
CompareVectorsNear(result, ref, 1.0E-4);
}

void TestNegativeConvertTo() {
TensorShape validShape(TensorLayout(eTensorLayout::TENSOR_LAYOUT_NHWC), {1, 1, 1, 1});
Tensor validGPUTensor(validShape, DataType(eDataType::DATA_TYPE_U8), eDeviceType::GPU);
Tensor validCPUTensor(validShape, DataType(eDataType::DATA_TYPE_U8), eDeviceType::CPU);
ConvertTo op;

{
// Test output tensor on CPU for GPU operation
EXPECT_EXCEPTION(op(nullptr, validGPUTensor, validCPUTensor, 1.0, 0.0, eDeviceType::GPU),
eStatusType::INVALID_COMBINATION);
}

{
// Test input tensor on CPU for GPU operation
EXPECT_EXCEPTION(op(nullptr, validCPUTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU),
eStatusType::INVALID_OPERATION);
}

{
// Test unsupported layout
TensorShape invalidLayoutShape(TensorLayout(eTensorLayout::TENSOR_LAYOUT_NC), {1, 1});
Tensor invalidTensor(invalidLayoutShape, DataType(eDataType::DATA_TYPE_U8), eDeviceType::GPU);
EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU),
eStatusType::INVALID_COMBINATION);
}

{
// Test unsupported data type
Tensor invalidTensor(validGPUTensor.shape(), DataType(eDataType::DATA_TYPE_U32), eDeviceType::GPU);
EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU), eStatusType::NOT_IMPLEMENTED);
}

{
// Test input/output channel mismatch
Tensor invalidTensor(TensorShape(validGPUTensor.layout(), {1, 1, 1, 2}), DataType(eDataType::DATA_TYPE_U8),
eDeviceType::GPU);
EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU),
eStatusType::INVALID_COMBINATION);
}

{
// Test input/output width/height mismatch
Tensor invalidTensor(TensorShape(validGPUTensor.layout(), {1, 2, 2, 1}), DataType(eDataType::DATA_TYPE_U8),
eDeviceType::GPU);
EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU),
eStatusType::INVALID_COMBINATION);
}

{
// Test input/output batch mismatch
Tensor invalidTensor(TensorShape(validGPUTensor.layout(), {2, 1, 1, 1}), DataType(eDataType::DATA_TYPE_U8),
eDeviceType::GPU);
EXPECT_EXCEPTION(op(nullptr, invalidTensor, validGPUTensor, 1.0, 0.0, eDeviceType::GPU),
eStatusType::INVALID_COMBINATION);
}
}

} // namespace

int main(int argc, char** argv) {
(void)argc;
(void)argv;
TEST_CASES_BEGIN();

TEST_CASE(TestNegativeConvertTo());

// CPU correctness tests
// 1 Channel
TEST_CASE((TestCorrectness<uchar1, uchar1>(1, 480, 360, FMT_U8, FMT_U8, 1.2, 10.2, eDeviceType::CPU)));
Expand Down