Skip to content

Commit 302a972

Browse files
author
Rolf Morel
committed
[transform_ext] Move populate_pattern op to dialects dir
Introduces op `transform_ext.populate_pattern TARGET_OP_KIND PAT_NAME PRIORITY` where patterns can be registered on `PopulatePatternOp` via its `name_to_pattern_rewrite` class member. Fixes #80.
1 parent c3dbe7c commit 302a972

9 files changed

Lines changed: 96 additions & 65 deletions

File tree

examples/workload/example.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from mlir.dialects import transform
1919
from mlir.execution_engine import ExecutionEngine
2020

21+
from lighthouse import dialects as lh_dialects
2122
from lighthouse.pipeline.helper import match
2223
from lighthouse.pipeline.opt import PassBundles, apply_bundle
2324

@@ -159,6 +160,8 @@ def schedule_modules(
159160

160161
if __name__ == "__main__":
161162
with ir.Context(), ir.Location.unknown():
163+
lh_dialects.register_and_load()
164+
162165
wload = ElementwiseSum(400, 400)
163166

164167
print(" Dump kernel ".center(60, "-"))

examples/workload/example_mlir.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from mlir.dialects import func, linalg, arith, memref
2222
from mlir.execution_engine import ExecutionEngine
2323

24+
from lighthouse import dialects as lh_dialects
2425
from lighthouse.workload import execute, benchmark
2526
import lighthouse.utils as lh_utils
2627

@@ -195,6 +196,8 @@ def payload_module(self):
195196

196197
if __name__ == "__main__":
197198
with ir.Context(), ir.Location.unknown():
199+
lh_dialects.register_and_load()
200+
198201
wload = ElementwiseSumMLIRAlloc(400, 400)
199202

200203
print(" Dump kernel ".center(60, "-"))

examples/xegpu/matmul.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from mlir import ir
2020
from mlir.execution_engine import ExecutionEngine
2121

22+
from lighthouse import dialects as lh_dialects
2223
from lighthouse.workload import benchmark, get_bench_wrapper_schedule
2324
from lighthouse.utils.memref import to_ctype as memref_to_ctype
2425
from lighthouse.utils.numpy import numpy_to_ctype
@@ -360,6 +361,8 @@ def parse_cli():
360361
c_type = "f32"
361362

362363
with ir.Context(), ir.Location.unknown():
364+
lh_dialects.register_and_load()
365+
363366
wload = XeGPUMatMul(
364367
M=M,
365368
N=N,

examples/xegpu/mlp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from mlir import ir
2525
from mlir.execution_engine import ExecutionEngine
2626

27+
from lighthouse import dialects as lh_dialects
2728
from lighthouse.workload import benchmark, get_bench_wrapper_schedule
2829
from lighthouse.utils.memref import to_ctype as memref_to_ctype
2930
from lighthouse.utils.numpy import numpy_to_ctype
@@ -375,6 +376,8 @@ def parse_cli():
375376
identity_weights = args.check_result
376377

377378
with ir.Context(), ir.Location.unknown():
379+
lh_dialects.register_and_load()
380+
378381
wload = XeGPUMLP(
379382
batch_size=args.batch_size,
380383
input_size=args.input_size,

lighthouse/dialects/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def register_and_load():
2+
from . import transform_ext
3+
4+
transform_ext.register_and_load()
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from mlir import rewrite, ir
2+
from mlir.dialects import ext, transform
3+
4+
5+
def register_and_load(context=None):
6+
TransformExtensionDialect.load()
7+
8+
9+
class TransformExtensionDialect(ext.Dialect, name="transform_ext"):
10+
@classmethod
11+
def load(cls, *args, **kwargs):
12+
super().load(*args, **kwargs)
13+
for op_cls in cls.operations:
14+
if hasattr(op_cls, "attach_interface_impls"):
15+
op_cls.attach_interface_impls()
16+
17+
18+
class PopulatePatternOp(TransformExtensionDialect.Operation, name="populate_pattern"):
19+
"""An operation to populate a pattern set with a specific pattern.
20+
21+
To be used in the region of `transform.apply_patterns`."""
22+
23+
op_kind: ir.StringAttr
24+
pattern_name: ir.StringAttr
25+
priority: ir.IntegerAttr
26+
27+
# A mapping from pattern names to their corresponding rewrite functions.
28+
# This should be populated by the users of this operation. In effect serves
29+
# as a registry for rewrite patterns that can be applied by this operation.
30+
name_to_rewrite_pattern = {}
31+
32+
@classmethod
33+
def attach_interface_impls(cls, context=None):
34+
cls.PatternDescriptorOpInterfaceModel.attach(
35+
cls.OPERATION_NAME, context=context
36+
)
37+
38+
class PatternDescriptorOpInterfaceModel(transform.PatternDescriptorOpInterface):
39+
@staticmethod
40+
def populate_patterns(
41+
op: "PopulatePatternOp",
42+
patternset: rewrite.RewritePatternSet,
43+
) -> None:
44+
patternset.add(
45+
op.op_kind.value,
46+
op.name_to_rewrite_pattern[op.pattern_name.value],
47+
benefit=op.priority.value,
48+
)
49+
50+
51+
def populate_pattern(
52+
op_kind: str, pattern_name: str, priority: int
53+
) -> PopulatePatternOp:
54+
"""Camelcase constructor for PopulatePatternOp."""
55+
priority_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), priority)
56+
return PopulatePatternOp(
57+
op_kind=ir.StringAttr.get(op_kind),
58+
pattern_name=ir.StringAttr.get(pattern_name),
59+
priority=priority_attr,
60+
)

lighthouse/schedule/pattern_schedule.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

lighthouse/workload/runner.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import numpy as np
66
import os
77
from mlir import ir
8-
from mlir.dialects import func, arith, scf, memref
8+
from mlir.dialects import func, arith, scf, memref, transform
99
from mlir.execution_engine import ExecutionEngine
1010
from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
11-
from lighthouse.schedule.pattern_schedule import pattern_rewrite_schedule
11+
from lighthouse.dialects.transform_ext import PopulatePatternOp, populate_pattern
12+
from lighthouse.schedule.utils import schedule_boilerplate
1213
from lighthouse.utils.mlir import func_cif, get_mlir_library_path
1314
from lighthouse.utils.memref import to_packed_args
1415
from lighthouse.workload import Workload
@@ -121,16 +122,24 @@ def bench(*args):
121122

122123

123124
def get_bench_wrapper_schedule(workload: Workload):
124-
return pattern_rewrite_schedule(
125-
{
126-
"func.func": bench_wrapper_pattern(
127-
workload.payload_function_name,
128-
workload.benchmark_function_name,
129-
)
130-
},
131-
"add_bench_pattern",
125+
PopulatePatternOp.name_to_rewrite_pattern["bench_wrapper"] = bench_wrapper_pattern(
126+
workload.payload_function_name,
127+
workload.benchmark_function_name,
132128
)
133129

130+
with schedule_boilerplate() as (schedule, named_seq):
131+
apply_patterns_op = transform.apply_patterns(named_seq.bodyTarget)
132+
with ir.InsertionPoint(apply_patterns_op.patterns):
133+
populate_pattern(
134+
op_kind="func.func",
135+
pattern_name="bench_wrapper",
136+
priority=1,
137+
)
138+
transform.yield_([named_seq.bodyTarget])
139+
140+
schedule.body.operations[0].verify()
141+
return schedule
142+
134143

135144
def benchmark(
136145
workload: Workload,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "lighthouse"
33
dynamic = ["version"]
44
requires-python = ">=3.10,<3.13" # Bounds are due to torch-mlir's packaging
55
dependencies = [
6-
"mlir-python-bindings==20260315+69780be1d"
6+
"mlir-python-bindings==20260316+f46a51538"
77
]
88

99
[dependency-groups]

0 commit comments

Comments
 (0)