Skip to content

Commit 7b6e022

Browse files
hsharma35meta-codesync[bot]
authored andcommitted
Add C++ unit tests for cadence::quantized_conv2d_nhwc (#18479)
Summary: Pull Request resolved: #18479 Add test_op_quantized_conv2d_nhwc.cpp covering 17 test combinations from the Python test_quantized_conv2d_nhwc_out. Tests cover Conv2d (4D NHWC) and Conv1d (3D NLC) variants including basic, stride, padding, and depthwise cases, using kernel registry dispatch. Reviewed By: RahulC7 Differential Revision: D96507563
1 parent 8f1b5ee commit 7b6e022

2 files changed

Lines changed: 52 additions & 0 deletions

File tree

backends/cadence/generic/operators/op_quantized_conv2d.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,40 @@ Tensor& quantized_conv2d_nhwc_per_tensor_out(
953953
return out;
954954
}
955955

956+
Tensor& quantized_conv2d_depthwise_nhwc_out(
957+
ET_UNUSED KernelRuntimeContext& ctx,
958+
const Tensor& input,
959+
const Tensor& weight,
960+
const Tensor& bias,
961+
IntArrayRef stride,
962+
IntArrayRef padding,
963+
IntArrayRef dilation,
964+
int64_t groups,
965+
int64_t in_zero_point,
966+
int64_t weight_zero_point,
967+
double bias_scale,
968+
double output_scale,
969+
int64_t output_zero_point,
970+
ET_UNUSED int64_t out_multiplier,
971+
ET_UNUSED int64_t out_shift,
972+
Tensor& out) {
973+
quantized_conv2d_nhwc(
974+
input,
975+
weight,
976+
bias,
977+
stride,
978+
padding,
979+
dilation,
980+
static_cast<int16_t>(groups),
981+
static_cast<int32_t>(in_zero_point),
982+
static_cast<int32_t>(weight_zero_point),
983+
static_cast<float>(bias_scale),
984+
static_cast<float>(output_scale),
985+
static_cast<int32_t>(output_zero_point),
986+
out);
987+
return out;
988+
}
989+
956990
Tensor& quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out(
957991
ET_UNUSED KernelRuntimeContext& ctx,
958992
const Tensor& input,

backends/cadence/generic/operators/op_quantized_conv2d.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,24 @@ ::executorch::aten::Tensor& quantized_conv2d_nhwc_per_tensor_out(
207207
int64_t out_shift,
208208
Tensor& out);
209209

210+
::executorch::aten::Tensor& quantized_conv2d_depthwise_nhwc_out(
211+
KernelRuntimeContext& ctx,
212+
const Tensor& input,
213+
const Tensor& weight,
214+
const Tensor& bias,
215+
IntArrayRef stride,
216+
IntArrayRef padding,
217+
IntArrayRef dilation,
218+
int64_t groups,
219+
int64_t in_zero_point,
220+
int64_t weight_zero_point,
221+
double bias_scale,
222+
double output_scale,
223+
int64_t output_zero_point,
224+
int64_t out_multiplier,
225+
int64_t out_shift,
226+
Tensor& out);
227+
210228
::executorch::aten::Tensor&
211229
quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out(
212230
KernelRuntimeContext& ctx,

0 commit comments

Comments
 (0)