Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backends/qualcomm/_passes/seq_mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
PerBlockParamObserver,
)
from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
PerChannelParamObserver,
)
from executorch.exir.pass_base import ExportPass, PassResult
from torchao.quantization.pt2e import PerChannelMinMaxObserver


class SeqMseModule(torch.nn.Module):
Expand Down Expand Up @@ -97,7 +99,7 @@ def _per_channel_qdq(self, scale, zero_point):

def _fake_quant(self, scale, zero_point):
dispatcher = {
PerChannelMinMaxObserver: self._per_channel_qdq,
PerChannelParamObserver: self._per_channel_qdq,
PerBlockParamObserver: self._per_block_qdq,
}
return dispatcher[type(self.observer)](scale, zero_point)
Expand Down
57 changes: 37 additions & 20 deletions backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,37 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
quantize_param_wrapper = std::make_unique<UndefinedQuantizeParamsWrapper>();
} else if (encoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
int32_t axis = quant_info["axis"].cast<int32_t>();
std::vector<Qnn_ScaleOffset_t> scale_offset =
quant_info["scale_offset"].cast<std::vector<Qnn_ScaleOffset_t>>();

auto so_arr =
quant_info["scale_offset"].cast<py::array_t<Qnn_ScaleOffset_t>>();
auto so_buf = so_arr.request();
const Qnn_ScaleOffset_t* so_ptr =
static_cast<const Qnn_ScaleOffset_t*>(so_buf.ptr);
std::vector<Qnn_ScaleOffset_t> scale_offset(so_ptr, so_ptr + so_buf.size);
quantize_param_wrapper =
std::make_unique<AxisScaleOffsetQuantizeParamsWrapper>(
axis, scale_offset);
axis, std::move(scale_offset));
} else if (encoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) {
uint32_t bitwidth = quant_info["bitwidth"].cast<uint32_t>();
int32_t axis = quant_info["axis"].cast<int32_t>();
std::vector<Qnn_ScaleOffset_t> scale_offset =
quant_info["scale_offset"].cast<std::vector<Qnn_ScaleOffset_t>>();
uint32_t num_elements = scale_offset.size();
std::vector<float> scales;
std::vector<int32_t> offsets;
for (const auto& scale_offset : scale_offset) {
scales.push_back(scale_offset.scale);
offsets.push_back(scale_offset.offset);
auto so_arr =
quant_info["scale_offset"].cast<py::array_t<Qnn_ScaleOffset_t>>();
auto so_buf = so_arr.request();
const Qnn_ScaleOffset_t* so_ptr =
static_cast<const Qnn_ScaleOffset_t*>(so_buf.ptr);
uint32_t num_elements = static_cast<uint32_t>(so_buf.size);
std::vector<float> scales(num_elements);
std::vector<int32_t> offsets(num_elements);
for (uint32_t i = 0; i < num_elements; ++i) {
scales[i] = so_ptr[i].scale;
offsets[i] = so_ptr[i].offset;
}
quantize_param_wrapper =
std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(
bitwidth, axis, num_elements, scales, offsets);
bitwidth,
axis,
num_elements,
std::move(scales),
std::move(offsets));
} else if (encoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) {
uint32_t bitwidth = quant_info["bitwidth"].cast<uint32_t>();
float scale = quant_info["scale"].cast<float>();
Expand All @@ -62,26 +72,32 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
std::make_unique<ScaleOffsetQuantizeParamsWrapper>(scale, offset);
} else if (encoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION) {
int32_t axis = quant_info["axis"].cast<int32_t>();
std::vector<Qnn_ScaleOffset_t> scale_offset =
quant_info["block_scale_offset"].cast<std::vector<Qnn_ScaleOffset_t>>();
auto so_arr =
quant_info["block_scale_offset"].cast<py::array_t<Qnn_ScaleOffset_t>>();
auto so_buf = so_arr.request();
const Qnn_ScaleOffset_t* so_ptr =
static_cast<const Qnn_ScaleOffset_t*>(so_buf.ptr);
std::vector<Qnn_ScaleOffset_t> scale_offset(so_ptr, so_ptr + so_buf.size);
uint32_t num_blocks_per_axis =
quant_info["num_blocks_per_axis"].cast<uint32_t>();
uint32_t block_scale_bitwidth =
quant_info["block_scale_bitwidth"].cast<uint32_t>();
Qnn_BlockwiseExpansionBlockScaleStorageType_t block_storage_type =
quant_info["block_storage_type"]
.cast<Qnn_BlockwiseExpansionBlockScaleStorageType_t>();
std::vector<uint8_t> buf =
quant_info["block_scales"].cast<std::vector<uint8_t>>();
py::array_t<uint8_t> block_scales_arr =
quant_info["block_scales"].cast<py::array_t<uint8_t>>();
auto buf_info = block_scales_arr.request();
const uint8_t* ptr = static_cast<const uint8_t*>(buf_info.ptr);
std::vector<uint8_t> block_scales_vec(ptr, ptr + buf_info.size);
quantize_param_wrapper =
std::make_unique<BlockwiseExpansionQuantizeParamsWrapper>(
axis,
scale_offset,
std::move(scale_offset),
num_blocks_per_axis,
block_scale_bitwidth,
block_storage_type,
buf.data(),
buf.size());
std::move(block_scales_vec));
} else {
QNN_EXECUTORCH_LOG_ERROR(
"Unknown the encoding of quantization: %d", encoding);
Expand Down Expand Up @@ -196,6 +212,7 @@ PYBIND11_MODULE(PyQnnManagerAdaptor, m) {
// TODO: Add related documents for configurations listed below
using namespace qnn_delegate;
PYBIND11_NUMPY_DTYPE(PyQnnTensorWrapper::EncodingData, scale, offset);
PYBIND11_NUMPY_DTYPE(Qnn_ScaleOffset_t, scale, offset);

m.def("GetQNNCtxBinAlignment", &GetQNNCtxBinAlignment);
m.def("GetQnnSdkBuildId", &GetQnnSdkBuildId);
Expand Down
6 changes: 4 additions & 2 deletions backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,17 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
quantization.blockwiseExpansion->scaleOffsets,
quantization.blockwiseExpansion->scaleOffsets +
QNN_TENSOR_VER_PTR(tensor)->dimensions[ch_axis]);
std::vector<uint8_t> block_scales(
quantization.blockwiseExpansion->blocksScale8,
quantization.blockwiseExpansion->blocksScale8 + block_scales_sz);
quantize_param_wrapper =
std::make_unique<BlockwiseExpansionQuantizeParamsWrapper>(
quantization.blockwiseExpansion->axis,
scale_offsets,
quantization.blockwiseExpansion->numBlocksPerAxis,
quantization.blockwiseExpansion->blockScaleBitwidth,
quantization.blockwiseExpansion->blockScaleStorageType,
quantization.blockwiseExpansion->blocksScale8,
block_scales_sz);
block_scales);
} else {
QNN_EXECUTORCH_LOG_ERROR(
"Unknown the encoding of quantization: %d",
Expand Down
21 changes: 8 additions & 13 deletions backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class BwAxisScaleOffsetQuantizeParamsWrapper final
bitwidth_(bitwidth),
axis_(axis),
num_elements_(num_elements),
scales_(scales),
offsets_(offsets) {}
scales_(std::move(scales)),
offsets_(std::move(offsets)) {}

BwAxisScaleOffsetQuantizeParamsWrapper(
const BwAxisScaleOffsetQuantizeParamsWrapper& rhs)
Expand Down Expand Up @@ -235,12 +235,12 @@ class AxisScaleOffsetQuantizeParamsWrapper final
public:
explicit AxisScaleOffsetQuantizeParamsWrapper(
std::int32_t axis,
const std::vector<Qnn_ScaleOffset_t>& scale_offsets)
std::vector<Qnn_ScaleOffset_t> scale_offsets)
: QuantizeParamsWrapper(
QNN_DEFINITION_DEFINED,
QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET),
axis_(axis),
scale_offsets_(scale_offsets) {}
scale_offsets_(std::move(scale_offsets)) {}

AxisScaleOffsetQuantizeParamsWrapper(
const AxisScaleOffsetQuantizeParamsWrapper& rhs)
Expand All @@ -249,8 +249,6 @@ class AxisScaleOffsetQuantizeParamsWrapper final
rhs.GetQuantizationEncoding()),
axis_(rhs.axis_),
scale_offsets_(rhs.scale_offsets_) {}
AxisScaleOffsetQuantizeParamsWrapper(
AxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
AxisScaleOffsetQuantizeParamsWrapper& operator=(
const AxisScaleOffsetQuantizeParamsWrapper& rhs) = delete;
AxisScaleOffsetQuantizeParamsWrapper& operator=(
Expand Down Expand Up @@ -286,21 +284,20 @@ class BlockwiseExpansionQuantizeParamsWrapper final
public:
explicit BlockwiseExpansionQuantizeParamsWrapper(
std::int32_t axis,
const std::vector<Qnn_ScaleOffset_t>& scale_offsets,
std::vector<Qnn_ScaleOffset_t> scale_offsets,
std::uint32_t num_blocks_per_axis,
std::uint32_t block_scale_bitwidth,
Qnn_BlockwiseExpansionBlockScaleStorageType_t storage_type,
const uint8_t* block_scales_ptr,
std::uint32_t block_scales_size)
std::vector<uint8_t> block_scales)
: QuantizeParamsWrapper(
QNN_DEFINITION_DEFINED,
QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION),
axis_(axis),
scale_offsets_(scale_offsets),
scale_offsets_(std::move(scale_offsets)),
num_blocks_per_axis_(num_blocks_per_axis),
block_scale_bitwidth_(block_scale_bitwidth),
block_storage_type_(storage_type),
block_scales_(block_scales_ptr, block_scales_ptr + block_scales_size) {}
block_scales_(std::move(block_scales)) {}

BlockwiseExpansionQuantizeParamsWrapper(
const BlockwiseExpansionQuantizeParamsWrapper& rhs)
Expand All @@ -314,8 +311,6 @@ class BlockwiseExpansionQuantizeParamsWrapper final
block_storage_type_(rhs.block_storage_type_),
block_scales_(rhs.block_scales_) {}

BlockwiseExpansionQuantizeParamsWrapper(
BlockwiseExpansionQuantizeParamsWrapper&& rhs) = delete;
BlockwiseExpansionQuantizeParamsWrapper& operator=(
const BlockwiseExpansionQuantizeParamsWrapper& rhs) = delete;
BlockwiseExpansionQuantizeParamsWrapper& operator=(
Expand Down
65 changes: 41 additions & 24 deletions backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Any, Dict, Tuple

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
Expand Down Expand Up @@ -151,8 +150,12 @@ def _get_tensor(node, index):
def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
import math

quant_config = copy.deepcopy(quant_attrs)
scales, scale_offset, quantized_scales = quant_attrs[QCOM_SCALE], [], []
quant_config = {
QCOM_DTYPE: quant_attrs[QCOM_DTYPE],
QCOM_QUANT_MIN: quant_attrs[QCOM_QUANT_MIN],
QCOM_QUANT_MAX: quant_attrs[QCOM_QUANT_MAX],
}
scales = quant_attrs[QCOM_SCALE]
# channel in observers defaults to zero
num_channels = node.meta["val"].shape[0]
user_0 = self.get_first_user(node)
Expand All @@ -170,17 +173,23 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
PyQnnManager.Qnn_BlockwiseExpansionBlockScaleStorageType_t.QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8
)

scale_offset_arr = np.empty(
num_channels, dtype=[("scale", np.float32), ("offset", np.int32)]
)
# move channel axis to dim 0 for transpose_conv case
candidates = scales if ch_axis == 0 else scales.transpose(0, 1)
candidates = candidates.reshape(num_channels, -1)
# find max scale per channel
max_scales = candidates.amax(dim=-1) / num_steps
# quantize scales per channel
q_scales = torch.clamp(
input=torch.round(input=candidates / max_scales.unsqueeze(-1)),
min=1,
max=2**bitwidth_of_scale,
).to(quant_scales_dtype)
# symmetric quantization is required
for ch in range(num_channels):
candidates = scales[ch] if ch_axis == 0 else scales[:, ch, ...]
max_scale = candidates.reshape(1, -1).amax(dim=-1) / num_steps
q_scales = torch.clamp(
input=torch.round(input=candidates / max_scale),
min=1,
max=2**bitwidth_of_scale,
).to(quant_scales_dtype)
quantized_scales.append(q_scales)
# symmetric quantization is required
scale_offset.append(PyQnnManager.Qnn_ScaleOffset_t(max_scale, 0))
scale_offset_arr[ch] = (float(max_scales[ch]), 0)

# skip dequantize op, e.g. frozen_param -> dq -> conv2d
user_0 = self.get_first_user(node)
Expand All @@ -195,9 +204,9 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
else:
raise AttributeError("undetermined axis for block quantization")

quant_config[QCOM_NUM_BLOCKS_PER_AXIS] = quantized_scales[0].shape.numel()
quant_config[QCOM_BLOCK_SCALE_OFFSET] = scale_offset
quant_config[QCOM_BLOCK_SCALES] = torch.cat(quantized_scales).detach().numpy()
quant_config[QCOM_NUM_BLOCKS_PER_AXIS] = q_scales.shape[1]
quant_config[QCOM_BLOCK_SCALE_OFFSET] = scale_offset_arr
quant_config[QCOM_BLOCK_SCALES] = q_scales.flatten().detach().numpy()
# e.g. if use 16 bit for quantized scales, we need to expand 16 - 4 = 12 bits
quant_config[QCOM_BLOCK_SCALE_BITWIDTH] = (
int(math.log2(torch.iinfo(quant_scales_dtype).max + 1)) - bitwidth_of_scale
Expand All @@ -209,20 +218,23 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
)

def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
quant_config = copy.deepcopy(quant_attrs)
quant_config = {
QCOM_DTYPE: quant_attrs[QCOM_DTYPE],
QCOM_QUANT_MAX: quant_attrs[QCOM_QUANT_MAX],
QCOM_QUANT_MIN: quant_attrs[QCOM_QUANT_MIN],
}

scales = quant_attrs[QCOM_SCALES]
zero_points = quant_attrs[QCOM_ZERO_POINTS]
assert len(scales) == len(
zero_points
), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"

scale_offset = []
scale_offset_arr = np.empty(
len(scales), dtype=[("scale", np.float32), ("offset", np.int32)]
)
for i in range(len(scales)):
# check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
scale_offset.append(
PyQnnManager.Qnn_ScaleOffset_t(scales[i], -zero_points[i])
)
scale_offset_arr[i] = (float(scales[i]), int(-zero_points[i]))

# skip dequantize op, e.g. frozen_param -> dq -> conv2d
user_0 = self.get_first_user(node)
Expand All @@ -232,7 +244,7 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
else:
quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

quant_config[QCOM_SCALE_OFFSET] = scale_offset
quant_config[QCOM_SCALE_OFFSET] = scale_offset_arr
# special case for 4 bits
if (
quant_config[QCOM_DTYPE] == torch.int8
Expand All @@ -249,7 +261,12 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
)

def make_qnn_per_tensor_config(self, quant_attrs: Dict):
quant_config = copy.deepcopy(quant_attrs)
quant_config = {
QCOM_DTYPE: quant_attrs[QCOM_DTYPE],
QCOM_SCALE: quant_attrs[QCOM_SCALE],
QCOM_QUANT_MAX: quant_attrs[QCOM_QUANT_MAX],
QCOM_QUANT_MIN: quant_attrs[QCOM_QUANT_MIN],
}
# check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
# special case for 4 bits
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/quantizer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def ptq_per_channel_quant_config(
quant_max=torch.iinfo(weight_dtype).max,
qscheme=torch.per_channel_symmetric,
ch_axis=0,
observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(**extra_args),
)

bias_quantization_spec = _derived_bias_quant_spec
Expand All @@ -142,7 +142,7 @@ def ptq_per_channel_quant_config(

return quantization_config
```
Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of IO activation and apply rules to `weight` w/`PerChannelMinMaxObserver`, `bias` w/`_derived_bias_quant_spec` (a callable method to calculate encoding in the desired way) to meet aforementioned constraints. The well-defined `quantization_config` will then be shipped to the callback for annotation.<br/>
Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of IO activation and apply rules to `weight` w/`PerChannelParamObserver`, `bias` w/`_derived_bias_quant_spec` (a callable method to calculate encoding in the desired way) to meet aforementioned constraints. The well-defined `quantization_config` will then be shipped to the callback for annotation.<br/>

Now, we can start to fill in the function body:
- Register annotator
Expand Down
8 changes: 5 additions & 3 deletions backends/qualcomm/quantizer/observers/concat_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.qualcomm.utils.constants import DEFAULT_EPS_FP32
from torchao.quantization.pt2e import UniformQuantizationObserverBase


Expand All @@ -23,7 +24,7 @@ def __init__(
quant_min=None,
quant_max=None,
factory_kwargs=None,
eps=torch.finfo(torch.float32).eps, # noqa: B008
eps=DEFAULT_EPS_FP32,
is_dynamic=False,
**kwargs,
) -> None:
Expand All @@ -49,8 +50,9 @@ def __init__(

def forward(self, x_orig):
# calculate the min / max first
self.min_val = min(self.min_val, x_orig.min())
self.max_val = max(self.max_val, x_orig.max())
min_val, max_val = torch.aminmax(x_orig.detach())
self.min_val = min(self.min_val, min_val)
self.max_val = max(self.max_val, max_val)

if len(self.input_observers) == 0:
# collect observers first if they are not cached
Expand Down
Loading
Loading