diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 54456877c1f..84e128a2d7f 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -13,6 +13,7 @@ get_8a8w_qnn_ptq_config, get_8a8w_qnn_qat_config, get_ptq_per_channel_quant_config, + get_16a4w_qnn_ptq_config, QuantizationConfig, ) from executorch.backends.qualcomm.quantizer.rules import ( @@ -27,8 +28,35 @@ annotate_output_qspec, QuantizationAnnotation, SharedQuantizationSpec, + QuantizationSpec, ) +def custom_annotation_16a4w_layer_norm(gm): + use_16a4w_config = get_16a4w_qnn_ptq_config() + use_16a4w_config.weight = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=15, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=use_16a4w_config.weight.observer_or_fake_quant_ctr, + ) + for node in gm.graph.nodes: + if node.target != torch.ops.aten.layer_norm.default: + continue + act_node = node.args[0] + weight_node = node.args[2] + bias_node = None + input_qspec_map = {act_node: use_16a4w_config.input_activation, weight_node: use_16a4w_config.weight} + if len(node.args) > 2: + bias_node = node.args[3] + input_qspec_map[bias_node] = use_16a4w_config.bias + + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=use_16a4w_config.output_activation, + _annotated=True, + ) def annotate_eurobert(gm: torch.fx.GraphModule): """ diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py index b28b4752c12..0fbb18fb6da 100644 --- a/examples/qualcomm/oss_scripts/llama/__init__.py +++ b/examples/qualcomm/oss_scripts/llama/__init__.py @@ -528,7 +528,7 @@ class InternVL3_1B(LLMModelConfig): convert_weights = convert_internvl3_weights transform_weight = False instruct_model = True - num_sharding = 1 + num_sharding = 8 masked_softmax = True seq_mse_candidates = 0 r1 = False diff --git a/examples/qualcomm/oss_scripts/llama/encoder/encoder_config.py b/examples/qualcomm/oss_scripts/llama/encoder/encoder_config.py index b8e32904bbf..cdb167247c1 100644 --- a/examples/qualcomm/oss_scripts/llama/encoder/encoder_config.py +++ b/examples/qualcomm/oss_scripts/llama/encoder/encoder_config.py @@ -31,6 +31,7 @@ class MultiModalityConfig(ABC): encoder_class: type quant_recipe: EncoderQuantRecipe + num_sharding: int = 1 @abstractmethod def create_encoder(self, config): @@ -88,3 +89,4 @@ class InternVL3Encoder(VisionModalityConfig): img_resized_w = 448 img_url = "http://images.cocodataset.org/val2017/000000039769.jpg" quant_recipe = InternVL3_Encoder_QuantRecipe + num_sharding = 8 diff --git a/examples/qualcomm/oss_scripts/llama/encoder/encoder_quant_recipe.py b/examples/qualcomm/oss_scripts/llama/encoder/encoder_quant_recipe.py index 58d34af1d45..db81fd40a07 100644 --- a/examples/qualcomm/oss_scripts/llama/encoder/encoder_quant_recipe.py +++ b/examples/qualcomm/oss_scripts/llama/encoder/encoder_quant_recipe.py @@ -13,6 +13,10 @@ ) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from torchao.quantization.pt2e import MinMaxObserver +from executorch.backends.qualcomm.quantizer.custom_annotation import ( + custom_annotate_matmul_16a8w, + custom_annotation_16a4w_layer_norm, +) class EncoderQuantRecipe: @@ -52,6 +56,7 @@ def __init__(self, verbose: bool = False): act_observer=MinMaxObserver, granularity=QuantGranularity.PER_CHANNEL, ) + self.recipe.custom_quant_annotations.extend([custom_annotate_matmul_16a8w, custom_annotation_16a4w_layer_norm]) class SmolVLM_Encoder_QuantRecipe(EncoderQuantRecipe): diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py index 4d63b2471ff..d91ff4618dd 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py @@ -974,15 +974,56 @@ def __init__( config.quant_recipe(True) if config.quant_recipe else None ) + # metadata + self.config = config + self.num_layers = auto_model.config.vision_config.num_hidden_layers + + self.passes_job = get_capture_program_passes() + self.dep_table = get_passes_dependency_for_capture_program() + + def _tag_ios(self, node, fixed_point_type): + quant_io_type = None + + # tag sharding io + if exir_ops.edge.llama.fallback.default in [ + u.target for u in list(node.users.keys()) + ] + [node.target]: + quant_io_type = fixed_point_type["io_type"] + + return quant_io_type + def compile(self, request: Request): if self.model is None: return request_data = request.method_data[self.modality] + # check if sharding required + if self.config.num_sharding > 1: + SplitGraph, setting = model_sharding.get_split_graph_pass( + self.num_layers, + shares=self.config.num_sharding, + pattern=r"vision_tower.encoder.layer.(\d+)", + ) + self.passes_job[SplitGraph] = setting + self.dep_table[SplitGraph] = [FoldQDQ] + self.dep_table[TagQuantIO] = [SplitGraph] + + if not request_data.skip_quantize: + fixed_point_type = {"io_type": torch.uint16} + + # setup quantized IO + self.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True + self.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "get_quant_io_dtype_fn" + ] = partial(self._tag_ios, fixed_point_type=fixed_point_type) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( module=self.model, inputs=self.example_input, compiler_specs=request_data.compile_spec, + dep_table=self.dep_table, + passes_job=self.passes_job, + skip_node_op_set={"llama.fallback.default"}, ) if self.control_args.verbose: print_delegation_info(edge_prog_mgr.exported_program().graph_module) diff --git a/extension/llm/custom_ops/model_sharding.py b/extension/llm/custom_ops/model_sharding.py index df87274a115..6838b0958a2 100644 --- a/extension/llm/custom_ops/model_sharding.py +++ b/extension/llm/custom_ops/model_sharding.py @@ -47,9 +47,10 @@ class SplitGraph(ExportPass): not load all llama model in one pte. """ - def __init__(self, shard_layers: List[int]): + def __init__(self, shard_layers: List[int], pattern=r"layers.(\d+)"): super().__init__() self.shard_layers = shard_layers + self.pattern = pattern def _insert_fallback_op( self, graph_module: torch.fx.GraphModule @@ -62,7 +63,6 @@ def _insert_fallback_op( The second partition will contain layers [4, 8). The third partition will contain layers [8, 12) and output. """ - pattern = r"layers.(\d+)" prev_node = None prev_layer = None for node in graph_module.graph.nodes: @@ -72,7 +72,7 @@ def _insert_fallback_op( module_values_list = list(node.meta["nn_module_stack"].values()) full_qualified_name = module_values_list[-1][0] # Search which layer this node belongs to - match = re.search(pattern, full_qualified_name) + match = re.search(self.pattern, full_qualified_name) if match is None: continue @@ -103,15 +103,20 @@ def call(self, graph_module: torch.fx.GraphModule): return PassResult(graph_module, True) -def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int): +def split_graph( + edge_program: ExportedProgram, num_layers: int, shares: int, pattern=r"layers.(\d+)" +): graph_module = edge_program.graph_module shard_layers = list(range(0, num_layers, int(num_layers / shares))) - return SplitGraph(shard_layers)(graph_module) + return SplitGraph(shard_layers, pattern=pattern)(graph_module) -def get_split_graph_pass(num_layers: int, shares: int): +def get_split_graph_pass(num_layers: int, shares: int, pattern=r"layers.(\d+)"): shard_layers = list(range(0, num_layers, int(num_layers / shares))) return SplitGraph, { QCOM_PASS_ACTIVATE_KEY: True, - QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: {"shard_layers": shard_layers}, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: { + "shard_layers": shard_layers, + "pattern": pattern, + }, }