Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions backends/qualcomm/quantizer/custom_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_8a8w_qnn_ptq_config,
get_8a8w_qnn_qat_config,
get_ptq_per_channel_quant_config,
get_16a4w_qnn_ptq_config,
QuantizationConfig,
)
from executorch.backends.qualcomm.quantizer.rules import (
Expand All @@ -27,8 +28,35 @@
annotate_output_qspec,
QuantizationAnnotation,
SharedQuantizationSpec,
QuantizationSpec,
)

def custom_annotation_16a4w_layer_norm(gm):
    """Annotate every ``aten.layer_norm`` node in *gm* with a 16a4w config.

    Activations keep the 16-bit settings from ``get_16a4w_qnn_ptq_config``;
    the weight qspec is overridden to a 4-bit (0..15) unsigned per-tensor
    spec so layer_norm weights are quantized with 4 bits.

    Args:
        gm: a ``torch.fx.GraphModule`` whose nodes will be annotated in place.
    """
    use_16a4w_config = get_16a4w_qnn_ptq_config()
    # Override only the weight spec; reuse the observer constructor from the
    # default 16a4w weight spec so observation behavior stays consistent.
    use_16a4w_config.weight = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=15,  # 4-bit range
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,  # not used by per-tensor schemes; kept for spec completeness
        observer_or_fake_quant_ctr=use_16a4w_config.weight.observer_or_fake_quant_ctr,
    )
    for node in gm.graph.nodes:
        if node.target != torch.ops.aten.layer_norm.default:
            continue
        # aten.layer_norm args: (input, normalized_shape, weight, bias, eps, ...)
        act_node = node.args[0]
        input_qspec_map = {act_node: use_16a4w_config.input_activation}
        if len(node.args) > 2 and node.args[2] is not None:
            input_qspec_map[node.args[2]] = use_16a4w_config.weight
        # BUGFIX: bias is args[3], but the original guard was
        # `len(node.args) > 2`, which indexed args[3] even when the node only
        # had 3 args (IndexError).  Also skip an explicit None bias so it is
        # not inserted as a qspec-map key.
        if len(node.args) > 3 and node.args[3] is not None:
            input_qspec_map[node.args[3]] = use_16a4w_config.bias

        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=use_16a4w_config.output_activation,
            _annotated=True,
        )

def annotate_eurobert(gm: torch.fx.GraphModule):
"""
Expand Down
2 changes: 1 addition & 1 deletion examples/qualcomm/oss_scripts/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ class InternVL3_1B(LLMModelConfig):
convert_weights = convert_internvl3_weights
transform_weight = False
instruct_model = True
num_sharding = 1
num_sharding = 8
masked_softmax = True
seq_mse_candidates = 0
r1 = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class MultiModalityConfig(ABC):

encoder_class: type
quant_recipe: EncoderQuantRecipe
num_sharding: int = 1

@abstractmethod
def create_encoder(self, config):
Expand Down Expand Up @@ -88,3 +89,4 @@ class InternVL3Encoder(VisionModalityConfig):
img_resized_w = 448
img_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
quant_recipe = InternVL3_Encoder_QuantRecipe
num_sharding = 8
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
)
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from torchao.quantization.pt2e import MinMaxObserver
from executorch.backends.qualcomm.quantizer.custom_annotation import (
custom_annotate_matmul_16a8w,
custom_annotation_16a4w_layer_norm,
)


class EncoderQuantRecipe:
Expand Down Expand Up @@ -52,6 +56,7 @@ def __init__(self, verbose: bool = False):
act_observer=MinMaxObserver,
granularity=QuantGranularity.PER_CHANNEL,
)
self.recipe.custom_quant_annotations.extend([custom_annotate_matmul_16a8w, custom_annotation_16a4w_layer_norm])


class SmolVLM_Encoder_QuantRecipe(EncoderQuantRecipe):
Expand Down
41 changes: 41 additions & 0 deletions examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,15 +974,56 @@ def __init__(
config.quant_recipe(True) if config.quant_recipe else None
)

# metadata
self.config = config
self.num_layers = auto_model.config.vision_config.num_hidden_layers

self.passes_job = get_capture_program_passes()
self.dep_table = get_passes_dependency_for_capture_program()

def _tag_ios(self, node, fixed_point_type):
quant_io_type = None

# tag sharding io
if exir_ops.edge.llama.fallback.default in [
u.target for u in list(node.users.keys())
] + [node.target]:
quant_io_type = fixed_point_type["io_type"]

return quant_io_type

def compile(self, request: Request):
if self.model is None:
return

request_data = request.method_data[self.modality]
# check if sharding required
if self.config.num_sharding > 1:
SplitGraph, setting = model_sharding.get_split_graph_pass(
self.num_layers,
shares=self.config.num_sharding,
pattern=r"vision_tower.encoder.layer.(\d+)",
)
self.passes_job[SplitGraph] = setting
self.dep_table[SplitGraph] = [FoldQDQ]
self.dep_table[TagQuantIO] = [SplitGraph]

if not request_data.skip_quantize:
fixed_point_type = {"io_type": torch.uint16}

# setup quantized IO
self.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
self.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
"get_quant_io_dtype_fn"
] = partial(self._tag_ios, fixed_point_type=fixed_point_type)

edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
module=self.model,
inputs=self.example_input,
compiler_specs=request_data.compile_spec,
dep_table=self.dep_table,
passes_job=self.passes_job,
skip_node_op_set={"llama.fallback.default"},
)
if self.control_args.verbose:
print_delegation_info(edge_prog_mgr.exported_program().graph_module)
Expand Down
19 changes: 12 additions & 7 deletions extension/llm/custom_ops/model_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ class SplitGraph(ExportPass):
not load all llama model in one pte.
"""

def __init__(self, shard_layers: List[int]):
def __init__(self, shard_layers: List[int], pattern=r"layers.(\d+)"):
super().__init__()
self.shard_layers = shard_layers
self.pattern = pattern

def _insert_fallback_op(
self, graph_module: torch.fx.GraphModule
Expand All @@ -62,7 +63,6 @@ def _insert_fallback_op(
The second partition will contain layers [4, 8).
The third partition will contain layers [8, 12) and output.
"""
pattern = r"layers.(\d+)"
prev_node = None
prev_layer = None
for node in graph_module.graph.nodes:
Expand All @@ -72,7 +72,7 @@ def _insert_fallback_op(
module_values_list = list(node.meta["nn_module_stack"].values())
full_qualified_name = module_values_list[-1][0]
# Search which layer this node belongs to
match = re.search(pattern, full_qualified_name)
match = re.search(self.pattern, full_qualified_name)
if match is None:
continue

Expand Down Expand Up @@ -103,15 +103,20 @@ def call(self, graph_module: torch.fx.GraphModule):
return PassResult(graph_module, True)


def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int):
def split_graph(
edge_program: ExportedProgram, num_layers: int, shares: int, pattern=r"layers.(\d+)"
):
graph_module = edge_program.graph_module
shard_layers = list(range(0, num_layers, int(num_layers / shares)))
return SplitGraph(shard_layers)(graph_module)
return SplitGraph(shard_layers, pattern=pattern)(graph_module)


def get_split_graph_pass(num_layers: int, shares: int):
def get_split_graph_pass(num_layers: int, shares: int, pattern=r"layers.(\d+)"):
shard_layers = list(range(0, num_layers, int(num_layers / shares)))
return SplitGraph, {
QCOM_PASS_ACTIVATE_KEY: True,
QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: {"shard_layers": shard_layers},
QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: {
"shard_layers": shard_layers,
"pattern": pattern,
},
}
Loading