diff --git a/backends/qualcomm/_passes/seq_mse.py b/backends/qualcomm/_passes/seq_mse.py index e0ac0a82b0a..7ec399699dd 100644 --- a/backends/qualcomm/_passes/seq_mse.py +++ b/backends/qualcomm/_passes/seq_mse.py @@ -11,8 +11,10 @@ from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import ( PerBlockParamObserver, ) +from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( + PerChannelParamObserver, +) from executorch.exir.pass_base import ExportPass, PassResult -from torchao.quantization.pt2e import PerChannelMinMaxObserver class SeqMseModule(torch.nn.Module): @@ -97,7 +99,7 @@ def _per_channel_qdq(self, scale, zero_point): def _fake_quant(self, scale, zero_point): dispatcher = { - PerChannelMinMaxObserver: self._per_channel_qdq, + PerChannelParamObserver: self._per_channel_qdq, PerBlockParamObserver: self._per_block_qdq, } return dispatcher[type(self.observer)](scale, zero_point) diff --git a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp index dbd629f81ba..8d7e8aae6bc 100644 --- a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp +++ b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp @@ -27,27 +27,37 @@ std::unique_ptr CreateQuantizationParamWrapper( quantize_param_wrapper = std::make_unique(); } else if (encoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { int32_t axis = quant_info["axis"].cast(); - std::vector scale_offset = - quant_info["scale_offset"].cast>(); - + auto so_arr = + quant_info["scale_offset"].cast>(); + auto so_buf = so_arr.request(); + const Qnn_ScaleOffset_t* so_ptr = + static_cast(so_buf.ptr); + std::vector scale_offset(so_ptr, so_ptr + so_buf.size); quantize_param_wrapper = std::make_unique( - axis, scale_offset); + axis, std::move(scale_offset)); } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { uint32_t bitwidth = quant_info["bitwidth"].cast(); int32_t axis = 
quant_info["axis"].cast(); - std::vector scale_offset = - quant_info["scale_offset"].cast>(); - uint32_t num_elements = scale_offset.size(); - std::vector scales; - std::vector offsets; - for (const auto& scale_offset : scale_offset) { - scales.push_back(scale_offset.scale); - offsets.push_back(scale_offset.offset); + auto so_arr = + quant_info["scale_offset"].cast>(); + auto so_buf = so_arr.request(); + const Qnn_ScaleOffset_t* so_ptr = + static_cast(so_buf.ptr); + uint32_t num_elements = static_cast(so_buf.size); + std::vector scales(num_elements); + std::vector offsets(num_elements); + for (uint32_t i = 0; i < num_elements; ++i) { + scales[i] = so_ptr[i].scale; + offsets[i] = so_ptr[i].offset; } quantize_param_wrapper = std::make_unique( - bitwidth, axis, num_elements, scales, offsets); + bitwidth, + axis, + num_elements, + std::move(scales), + std::move(offsets)); } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) { uint32_t bitwidth = quant_info["bitwidth"].cast(); float scale = quant_info["scale"].cast(); @@ -62,8 +72,12 @@ std::unique_ptr CreateQuantizationParamWrapper( std::make_unique(scale, offset); } else if (encoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION) { int32_t axis = quant_info["axis"].cast(); - std::vector scale_offset = - quant_info["block_scale_offset"].cast>(); + auto so_arr = + quant_info["block_scale_offset"].cast>(); + auto so_buf = so_arr.request(); + const Qnn_ScaleOffset_t* so_ptr = + static_cast(so_buf.ptr); + std::vector scale_offset(so_ptr, so_ptr + so_buf.size); uint32_t num_blocks_per_axis = quant_info["num_blocks_per_axis"].cast(); uint32_t block_scale_bitwidth = @@ -71,17 +85,19 @@ std::unique_ptr CreateQuantizationParamWrapper( Qnn_BlockwiseExpansionBlockScaleStorageType_t block_storage_type = quant_info["block_storage_type"] .cast(); - std::vector buf = - quant_info["block_scales"].cast>(); + py::array_t block_scales_arr = + quant_info["block_scales"].cast>(); + auto buf_info = 
block_scales_arr.request(); + const uint8_t* ptr = static_cast(buf_info.ptr); + std::vector block_scales_vec(ptr, ptr + buf_info.size); quantize_param_wrapper = std::make_unique( axis, - scale_offset, + std::move(scale_offset), num_blocks_per_axis, block_scale_bitwidth, block_storage_type, - buf.data(), - buf.size()); + std::move(block_scales_vec)); } else { QNN_EXECUTORCH_LOG_ERROR( "Unknown the encoding of quantization: %d", encoding); @@ -196,6 +212,7 @@ PYBIND11_MODULE(PyQnnManagerAdaptor, m) { // TODO: Add related documents for configurations listed below using namespace qnn_delegate; PYBIND11_NUMPY_DTYPE(PyQnnTensorWrapper::EncodingData, scale, offset); + PYBIND11_NUMPY_DTYPE(Qnn_ScaleOffset_t, scale, offset); m.def("GetQNNCtxBinAlignment", &GetQNNCtxBinAlignment); m.def("GetQnnSdkBuildId", &GetQnnSdkBuildId); diff --git a/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp b/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp index 81b6d04855c..3b5ffbc4e0d 100644 --- a/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp +++ b/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp @@ -77,6 +77,9 @@ std::unique_ptr CreateQuantizationParamWrapper( quantization.blockwiseExpansion->scaleOffsets, quantization.blockwiseExpansion->scaleOffsets + QNN_TENSOR_VER_PTR(tensor)->dimensions[ch_axis]); + std::vector block_scales( + quantization.blockwiseExpansion->blocksScale8, + quantization.blockwiseExpansion->blocksScale8 + block_scales_sz); quantize_param_wrapper = std::make_unique( quantization.blockwiseExpansion->axis, @@ -84,8 +87,7 @@ std::unique_ptr CreateQuantizationParamWrapper( quantization.blockwiseExpansion->numBlocksPerAxis, quantization.blockwiseExpansion->blockScaleBitwidth, quantization.blockwiseExpansion->blockScaleStorageType, - quantization.blockwiseExpansion->blocksScale8, - block_scales_sz); + block_scales); } else { QNN_EXECUTORCH_LOG_ERROR( "Unknown the encoding of quantization: %d", diff --git 
a/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h b/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h index c2532c96388..86d137723aa 100644 --- a/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h +++ b/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h @@ -92,8 +92,8 @@ class BwAxisScaleOffsetQuantizeParamsWrapper final bitwidth_(bitwidth), axis_(axis), num_elements_(num_elements), - scales_(scales), - offsets_(offsets) {} + scales_(std::move(scales)), + offsets_(std::move(offsets)) {} BwAxisScaleOffsetQuantizeParamsWrapper( const BwAxisScaleOffsetQuantizeParamsWrapper& rhs) @@ -235,12 +235,12 @@ class AxisScaleOffsetQuantizeParamsWrapper final public: explicit AxisScaleOffsetQuantizeParamsWrapper( std::int32_t axis, - const std::vector& scale_offsets) + std::vector scale_offsets) : QuantizeParamsWrapper( QNN_DEFINITION_DEFINED, QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET), axis_(axis), - scale_offsets_(scale_offsets) {} + scale_offsets_(std::move(scale_offsets)) {} AxisScaleOffsetQuantizeParamsWrapper( const AxisScaleOffsetQuantizeParamsWrapper& rhs) @@ -249,8 +249,6 @@ class AxisScaleOffsetQuantizeParamsWrapper final rhs.GetQuantizationEncoding()), axis_(rhs.axis_), scale_offsets_(rhs.scale_offsets_) {} - AxisScaleOffsetQuantizeParamsWrapper( - AxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete; AxisScaleOffsetQuantizeParamsWrapper& operator=( const AxisScaleOffsetQuantizeParamsWrapper& rhs) = delete; AxisScaleOffsetQuantizeParamsWrapper& operator=( @@ -286,21 +284,20 @@ class BlockwiseExpansionQuantizeParamsWrapper final public: explicit BlockwiseExpansionQuantizeParamsWrapper( std::int32_t axis, - const std::vector& scale_offsets, + std::vector scale_offsets, std::uint32_t num_blocks_per_axis, std::uint32_t block_scale_bitwidth, Qnn_BlockwiseExpansionBlockScaleStorageType_t storage_type, - const uint8_t* block_scales_ptr, - std::uint32_t block_scales_size) + std::vector block_scales) : QuantizeParamsWrapper( QNN_DEFINITION_DEFINED, 
QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION), axis_(axis), - scale_offsets_(scale_offsets), + scale_offsets_(std::move(scale_offsets)), num_blocks_per_axis_(num_blocks_per_axis), block_scale_bitwidth_(block_scale_bitwidth), block_storage_type_(storage_type), - block_scales_(block_scales_ptr, block_scales_ptr + block_scales_size) {} + block_scales_(std::move(block_scales)) {} BlockwiseExpansionQuantizeParamsWrapper( const BlockwiseExpansionQuantizeParamsWrapper& rhs) @@ -314,8 +311,6 @@ class BlockwiseExpansionQuantizeParamsWrapper final block_storage_type_(rhs.block_storage_type_), block_scales_(rhs.block_scales_) {} - BlockwiseExpansionQuantizeParamsWrapper( - BlockwiseExpansionQuantizeParamsWrapper&& rhs) = delete; BlockwiseExpansionQuantizeParamsWrapper& operator=( const BlockwiseExpansionQuantizeParamsWrapper& rhs) = delete; BlockwiseExpansionQuantizeParamsWrapper& operator=( diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index fa118829f00..1a5033b1b56 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import copy from typing import Any, Dict, Tuple import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager @@ -151,8 +150,12 @@ def _get_tensor(node, index): def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict): import math - quant_config = copy.deepcopy(quant_attrs) - scales, scale_offset, quantized_scales = quant_attrs[QCOM_SCALE], [], [] + quant_config = { + QCOM_DTYPE: quant_attrs[QCOM_DTYPE], + QCOM_QUANT_MIN: quant_attrs[QCOM_QUANT_MIN], + QCOM_QUANT_MAX: quant_attrs[QCOM_QUANT_MAX], + } + scales = quant_attrs[QCOM_SCALE] # channel in observers defaults to zero num_channels = node.meta["val"].shape[0] user_0 = self.get_first_user(node) @@ -170,17 +173,23 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict): PyQnnManager.Qnn_BlockwiseExpansionBlockScaleStorageType_t.QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8 ) + scale_offset_arr = np.empty( + num_channels, dtype=[("scale", np.float32), ("offset", np.int32)] + ) + # move channel axis to dim 0 for transpose_conv case + candidates = scales if ch_axis == 0 else scales.transpose(0, 1) + candidates = candidates.reshape(num_channels, -1) + # find max scale per channel + max_scales = candidates.amax(dim=-1) / num_steps + # quantize scales per channel + q_scales = torch.clamp( + input=torch.round(input=candidates / max_scales.unsqueeze(-1)), + min=1, + max=2**bitwidth_of_scale, + ).to(quant_scales_dtype) + # symmetric quantization is required for ch in range(num_channels): - candidates = scales[ch] if ch_axis == 0 else scales[:, ch, ...] 
- max_scale = candidates.reshape(1, -1).amax(dim=-1) / num_steps - q_scales = torch.clamp( - input=torch.round(input=candidates / max_scale), - min=1, - max=2**bitwidth_of_scale, - ).to(quant_scales_dtype) - quantized_scales.append(q_scales) - # symmetric quantization is required - scale_offset.append(PyQnnManager.Qnn_ScaleOffset_t(max_scale, 0)) + scale_offset_arr[ch] = (float(max_scales[ch]), 0) # skip dequantize op, e.g. frozen_param -> dq -> conv2d user_0 = self.get_first_user(node) @@ -195,9 +204,9 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict): else: raise AttributeError("undetermined axis for block quantization") - quant_config[QCOM_NUM_BLOCKS_PER_AXIS] = quantized_scales[0].shape.numel() - quant_config[QCOM_BLOCK_SCALE_OFFSET] = scale_offset - quant_config[QCOM_BLOCK_SCALES] = torch.cat(quantized_scales).detach().numpy() + quant_config[QCOM_NUM_BLOCKS_PER_AXIS] = q_scales.shape[1] + quant_config[QCOM_BLOCK_SCALE_OFFSET] = scale_offset_arr + quant_config[QCOM_BLOCK_SCALES] = q_scales.flatten().detach().numpy() # e.g. 
if use 16 bit for quantized scales, we need to expand 16 - 4 = 12 bits quant_config[QCOM_BLOCK_SCALE_BITWIDTH] = ( int(math.log2(torch.iinfo(quant_scales_dtype).max + 1)) - bitwidth_of_scale @@ -209,7 +218,11 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict): ) def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): - quant_config = copy.deepcopy(quant_attrs) + quant_config = { + QCOM_DTYPE: quant_attrs[QCOM_DTYPE], + QCOM_QUANT_MAX: quant_attrs[QCOM_QUANT_MAX], + QCOM_QUANT_MIN: quant_attrs[QCOM_QUANT_MIN], + } scales = quant_attrs[QCOM_SCALES] zero_points = quant_attrs[QCOM_ZERO_POINTS] @@ -217,12 +230,11 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): zero_points ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}" - scale_offset = [] + scale_offset_arr = np.empty( + len(scales), dtype=[("scale", np.float32), ("offset", np.int32)] + ) for i in range(len(scales)): - # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h - scale_offset.append( - PyQnnManager.Qnn_ScaleOffset_t(scales[i], -zero_points[i]) - ) + scale_offset_arr[i] = (float(scales[i]), int(-zero_points[i])) # skip dequantize op, e.g. 
frozen_param -> dq -> conv2d user_0 = self.get_first_user(node) @@ -232,7 +244,7 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): else: quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS] - quant_config[QCOM_SCALE_OFFSET] = scale_offset + quant_config[QCOM_SCALE_OFFSET] = scale_offset_arr # special case for 4 bits if ( quant_config[QCOM_DTYPE] == torch.int8 @@ -249,7 +261,12 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): ) def make_qnn_per_tensor_config(self, quant_attrs: Dict): - quant_config = copy.deepcopy(quant_attrs) + quant_config = { + QCOM_DTYPE: quant_attrs[QCOM_DTYPE], + QCOM_SCALE: quant_attrs[QCOM_SCALE], + QCOM_QUANT_MAX: quant_attrs[QCOM_QUANT_MAX], + QCOM_QUANT_MIN: quant_attrs[QCOM_QUANT_MIN], + } # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT] # special case for 4 bits diff --git a/backends/qualcomm/quantizer/README.md b/backends/qualcomm/quantizer/README.md index 1f1868ff007..e03067265b5 100644 --- a/backends/qualcomm/quantizer/README.md +++ b/backends/qualcomm/quantizer/README.md @@ -128,7 +128,7 @@ def ptq_per_channel_quant_config( quant_max=torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, ch_axis=0, - observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), + observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(**extra_args), ) bias_quantization_spec = _derived_bias_quant_spec @@ -142,7 +142,7 @@ def ptq_per_channel_quant_config( return quantization_config ``` -Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of IO activation and apply rules to `weight` w/`PerChannelMinMaxObserver`, `bias` w/`_derived_bias_quant_spec` (a callable method to calculate encoding in desired way) to meet aforementioned constraints. The well-defined `quantizaton_config` will then be shipped to callback for annotation.
+Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of IO activation and apply rules to `weight` w/`PerChannelParamObserver`, `bias` w/`_derived_bias_quant_spec` (a callable method to calculate encoding in desired way) to meet aforementioned constraints. The well-defined `quantization_config` will then be shipped to callback for annotation.
Now, we can start to fill in the function body: - Register annotator diff --git a/backends/qualcomm/quantizer/observers/concat_observer.py b/backends/qualcomm/quantizer/observers/concat_observer.py index cd2a1a99805..d52018531c5 100644 --- a/backends/qualcomm/quantizer/observers/concat_observer.py +++ b/backends/qualcomm/quantizer/observers/concat_observer.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.utils.constants import DEFAULT_EPS_FP32 from torchao.quantization.pt2e import UniformQuantizationObserverBase @@ -23,7 +24,7 @@ def __init__( quant_min=None, quant_max=None, factory_kwargs=None, - eps=torch.finfo(torch.float32).eps, # noqa: B008 + eps=DEFAULT_EPS_FP32, is_dynamic=False, **kwargs, ) -> None: @@ -49,8 +50,9 @@ def __init__( def forward(self, x_orig): # calculate the min / max first - self.min_val = min(self.min_val, x_orig.min()) - self.max_val = max(self.max_val, x_orig.max()) + min_val, max_val = torch.aminmax(x_orig.detach()) + self.min_val = min(self.min_val, min_val) + self.max_val = max(self.max_val, max_val) if len(self.input_observers) == 0: # collect observers first if they are not cached diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py index 13ab51008ed..2b2b0968723 100644 --- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py @@ -7,6 +7,7 @@ from typing import Tuple import torch +from executorch.backends.qualcomm.utils.constants import DEFAULT_EPS_FP32 from torchao.quantization.pt2e import FakeQuantize, MappingType, PerBlock from torchao.quantization.pt2e._affine_quantization import ( _get_reduction_params, @@ -23,7 +24,7 @@ def __init__( block_size: torch.Size, quant_min=None, quant_max=None, - eps=torch.finfo(torch.float32).eps, # noqa: B008 + eps=DEFAULT_EPS_FP32, **kwargs, ): 
super().__init__( @@ -99,7 +100,7 @@ def __init__( block_size: torch.Size = None, quant_min: int = None, quant_max: int = None, - eps: float = torch.finfo(torch.float32).eps, # noqa: B008 + eps: float = DEFAULT_EPS_FP32, **kwargs, ): super().__init__() diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py index 9f89f6b0e69..4f73a5bc021 100644 --- a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py @@ -5,11 +5,15 @@ # LICENSE file in the root directory of this source tree. import torch -from torchao.quantization.pt2e import UniformQuantizationObserverBase +from executorch.backends.qualcomm.utils.constants import DEFAULT_EPS_FP32 +from torchao.quantization.pt2e import ( + PerChannelMinMaxObserver, + UniformQuantizationObserverBase, +) # TODO move to torch/ao/quantization/observer.py. -class PerChannelParamObserver(UniformQuantizationObserverBase): +class PerChannelParamObserverWithLossEvaluation(UniformQuantizationObserverBase): """ Minimize quantization loss caused by outlier via linear search. 
More details can be found at https://arxiv.org/pdf/2209.13325 """ @@ -25,7 +29,7 @@ def __init__( quant_min=None, quant_max=None, factory_kwargs=None, - eps=torch.finfo(torch.float32).eps, # noqa: B008 + eps=DEFAULT_EPS_FP32, is_dynamic=False, **kwargs, ) -> None: @@ -111,3 +115,43 @@ def forward(self, x_orig): def calculate_qparams(self): return self._calculate_qparams(self.min_val, self.max_val) + + +class PerChannelParamObserver(PerChannelMinMaxObserver): + """ + Bypass redundant calibration for static parameters + """ + + def __init__( + self, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=DEFAULT_EPS_FP32, + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + ch_axis=ch_axis, + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + self.calibrated = False + + def forward(self, x_orig): + if self.calibrated: + return x_orig + + self.calibrated = True + return self._forward(x_orig) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index f281692a2d4..5adc8e526f1 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -13,6 +13,13 @@ PerBlockParamFakeQuantize, PerBlockParamObserver, ) +from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( + PerChannelParamObserver, +) +from executorch.backends.qualcomm.utils.constants import ( + DEFAULT_EPS_16BIT, + DEFAULT_EPS_8BIT, +) from torch import Tensor from torch.fx import Node from torchao.quantization.pt2e import ( @@ -21,16 +28,12 @@ MinMaxObserver, MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver, - PerChannelMinMaxObserver, ) from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, QuantizationSpec, ) 
-DEFAULT_EPS_8BIT = 0.0001 / 255 -DEFAULT_EPS_16BIT = 0.0001 / 65535 - @dataclass(eq=True) class QuantizationConfig: @@ -472,7 +475,7 @@ def get_ptq_per_channel_quant_config( quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, ch_axis=ch_axis, - observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), + observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(**extra_args), ) bias_quantization_spec = _derived_bias_quant_spec diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py index 75723beebdc..9a01d4ed44c 100644 --- a/backends/qualcomm/utils/constants.py +++ b/backends/qualcomm/utils/constants.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch + # Qualcomm specific key # constants in backends/qualcomm/_passes & backends/qualcomm/builders @@ -53,3 +55,8 @@ HEXAGON_SDK_ROOT = "HEXAGON_SDK_ROOT" HEXAGON_TOOLS_ROOT = "HEXAGON_TOOLS_ROOT" DSP_VERSION = "DSP_VERSION" + +# Eps constants for quantizer +DEFAULT_EPS_8BIT = 0.0001 / 255 +DEFAULT_EPS_16BIT = 0.0001 / 65535 +DEFAULT_EPS_FP32 = torch.finfo(torch.float32).eps diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py index 87d90bb61b7..81a2c42c8d3 100644 --- a/examples/qualcomm/oss_scripts/fastvit.py +++ b/examples/qualcomm/oss_scripts/fastvit.py @@ -13,7 +13,7 @@ import torch from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( - PerChannelParamObserver, + PerChannelParamObserverWithLossEvaluation, ) from executorch.backends.qualcomm.quantizer.qconfig import ( _derived_bias_quant_spec, @@ -92,7 +92,7 @@ def get_custom_quantizer(backend, soc_model): quant_max=torch.iinfo(torch.int8).max, qscheme=torch.per_channel_symmetric, ch_axis=0, - observer_or_fake_quant_ctr=PerChannelParamObserver.with_args( + 
observer_or_fake_quant_ctr=PerChannelParamObserverWithLossEvaluation.with_args( **{"steps": 100, "use_mse": True} ), ) diff --git a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py index 626be27be44..0a1b487354d 100644 --- a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py +++ b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py @@ -17,7 +17,7 @@ from datasets import load_dataset from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( - PerChannelParamObserver, + PerChannelParamObserverWithLossEvaluation, ) from executorch.backends.qualcomm.quantizer.qconfig import ( _derived_bias_quant_spec, @@ -96,7 +96,7 @@ def add_mse_weight_observer(quant_dtype, quantizer): quant_max=(7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max), qscheme=torch.per_channel_symmetric, ch_axis=0, - observer_or_fake_quant_ctr=PerChannelParamObserver.with_args( + observer_or_fake_quant_ctr=PerChannelParamObserverWithLossEvaluation.with_args( **{"steps": 200, "use_mse": True} ), ) diff --git a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py index 3a1bab412de..ef5bfc715c3 100644 --- a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py +++ b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( - PerChannelParamObserver, + PerChannelParamObserverWithLossEvaluation, ) from executorch.backends.qualcomm.serialization.qc_schema import ( @@ -68,7 +68,7 @@ def forward( return self.model.forward(tokens, self.atten_mask) -class PerChannelMSEObserver(PerChannelParamObserver): +class PerChannelMSEObserver(PerChannelParamObserverWithLossEvaluation): def forward(self, x_orig): # since params are static, one calibration is enough