
Commit 2ab5e79

hsharma35 authored and meta-codesync[bot] committed
Promote pre-set spec.mem_id to AbsolutePlacementConstraint. (pytorch#18480)
Summary:
Pull Request resolved: pytorch#18480

In plan_with_constraints(), specs that arrive with a pre-set mem_id (e.g. planned-temporary alloc nodes whose spec.mem_id is pinned by the AOT pass) are promoted to AbsolutePlacementConstraint before the solver runs. The planner then assigns only the offset within that tier, leaving mem_id intact. Specs that already have an explicit AbsolutePlacementConstraint are not double-promoted; the existing constraint wins.

Adds test target test_memory_planning_algo with 8 regression tests:
- an unpinned spec is placed freely in mem_id=1
- a pinned spec stays in the requested tier
- two specs pinned to the same tier get non-overlapping offsets
- unpinned peers of pinned specs are not forced into the pinned tier
- an externally set AbsolutePlacementConstraint overrides spec.mem_id
- mem_id=0 is treated as the "unassigned" sentinel and planned freely
- a negative mem_id is treated as unpinned and planned freely
- an out-of-range mem_id raises an AssertionError

Differential Revision: D95413373
1 parent b9a6b84 commit 2ab5e79
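
In practice, a spec whose mem_id was pre-set by an AOT pass keeps that tier and only receives an offset, while unpinned specs are still placed freely. Below is a minimal sketch of that behaviour, modelled directly on the new regression tests in this commit; the MemoryConfig sizes, tensor shapes, and lifetimes are illustrative, not canonical.

import torch
import torch.fx

from executorch.backends.cadence.aot.memory_constraints import MemConstraints
from executorch.backends.cadence.aot.memory_planning import PositionBasedGreedyWithHierarchy
from executorch.backends.cadence.aot.memory_planning_algo import MemoryPlanningState
from executorch.backends.cadence.aot.utils import MemoryConfig
from executorch.exir.tensor import TensorSpec

# Two memory tiers: mem_id 1 and mem_id 2 (sizes are arbitrary for this sketch).
config = MemoryConfig([4096, 4096])
algo = PositionBasedGreedyWithHierarchy(config)
state = MemoryPlanningState(config)
constraints = MemConstraints()

pinned = TensorSpec(dtype=torch.uint8, shape=torch.Size([256]))
pinned.lifetime = [0, 1]
pinned.mem_id = 2  # pre-set by an AOT pass; promoted to AbsolutePlacementConstraint

free = TensorSpec(dtype=torch.uint8, shape=torch.Size([64]))
free.lifetime = [0, 1]  # no mem_id: the planner picks a tier freely

gm = torch.fx.GraphModule({}, torch.fx.Graph())
algo.plan_with_constraints([pinned, free], gm, None, state, constraints)

assert pinned.mem_id == 2  # tier preserved; only the offset was assigned
assert free.mem_id == 1    # the greedy planner prefers the first tier for unpinned specs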

4 files changed

Lines changed: 211 additions & 0 deletions


backends/cadence/aot/BUCK

Lines changed: 18 additions & 0 deletions
@@ -574,6 +574,7 @@ fbcode_target(_kind = runtime.python_library,
     srcs = [
         "memory_planning_algo.py",
     ],
+    typing = True,
     deps = [
         ":memory_constraints",
         ":pass_utils",
@@ -618,6 +619,23 @@ fbcode_target(_kind = runtime.python_library,
     ],
 )
 
+fbcode_target(_kind = python_unittest,
+    name = "test_memory_planning_algo",
+    srcs = [
+        "tests/test_memory_planning_algo.py",
+    ],
+    supports_static_listing = False,
+    typing = True,
+    deps = [
+        ":memory_constraints",
+        ":memory_planning",
+        ":memory_planning_algo",
+        ":utils",
+        "//caffe2:torch",
+        "//executorch/exir:tensor",
+    ],
+)
+
 fbcode_target(_kind = python_unittest,
     name = "test_memory_passes",
     srcs = [

backends/cadence/aot/memory_constraints.py

Lines changed: 11 additions & 0 deletions
@@ -324,6 +324,17 @@ def add_absolute_placement_constraint(
             )
         )
 
+    def set_absolute_placement_constraint(
+        self, spec: TensorSpec, constraint: AbsolutePlacementConstraint
+    ) -> None:
+        """Set an absolute placement constraint for `spec` by spec identity.
+
+        Overwrites any existing constraint for the same spec. Range validation
+        of pinned_memory_id is the caller's responsibility (depends on the
+        planner's MemoryConfig).
+        """
+        self._absolute_placement_constraints[id(spec)] = constraint
+
     def get_absolute_placement_constraint(
         self, spec: TensorSpec
     ) -> Optional[AbsolutePlacementConstraint]:
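
A short usage sketch of the new method, following the pattern in the added tests: constraints are keyed by spec identity (id(spec)), so the same spec object must later be handed to the planner. The tensor shape here is illustrative.

import torch

from executorch.backends.cadence.aot.memory_constraints import (
    AbsolutePlacementConstraint,
    MemConstraints,
)
from executorch.exir.tensor import TensorSpec

constraints = MemConstraints()
spec = TensorSpec(dtype=torch.uint8, shape=torch.Size([256]))

# Pin this spec to memory tier 2; any earlier constraint for the same spec is overwritten.
constraints.set_absolute_placement_constraint(
    spec, AbsolutePlacementConstraint(pinned_memory_id=2)
)

# Lookup is by id(spec), so it only matches this exact spec object.
assert constraints.get_absolute_placement_constraint(spec).pinned_memory_id == 2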

backends/cadence/aot/memory_planning_algo.py

Lines changed: 34 additions & 0 deletions
@@ -283,6 +283,40 @@ def plan_with_constraints(
     ) -> None:
         """Callable interface for ET memory planning."""
 
+        # Promote specs with a pre-set mem_id to AbsolutePlacementConstraint so
+        # the planner honours the pinned memory tier and only assigns the offset.
+        # This is used by planned-temporary alloc nodes whose spec.mem_id is set
+        # by the AOT pass before planning runs.
+        #
+        # mem_id semantics:
+        #   None — not yet assigned (default); planner picks freely
+        #   <= 0 — sentinel for "unassigned/unpinned"; planner picks freely
+        #   [1, num_memories) — valid tier; promoted to constraint below
+        #
+        # Materialize to list because collect_specs_from_nodes returns a
+        # generator and we iterate twice (promotion here, constraint
+        # collection in spec_and_abs_constraints below).
+        specs = list(specs)
+        for spec in specs:
+            if (
+                spec.mem_id is not None
+                and isinstance(spec.mem_id, int)
+                and spec.mem_id > 0
+                and placement_constraints.get_absolute_placement_constraint(spec)
+                is None
+            ):
+                num_memories = self.get_num_memories()
+                assert 1 <= spec.mem_id < num_memories, (
+                    f"Pre-set spec.mem_id={spec.mem_id} is invalid. "
+                    f"Memory IDs must be in range [1, {num_memories}) for this planner configuration. "
+                    f"Check that the spec.mem_id was set correctly in the AOT pass, "
+                    f"or verify your MemoryConfig defines enough memory tiers."
+                )
+                placement_constraints.set_absolute_placement_constraint(
+                    spec,
+                    AbsolutePlacementConstraint(pinned_memory_id=spec.mem_id),
+                )
+
         spec_and_abs_constraints = {
             spec: placement_constraints.get_absolute_placement_constraint(spec)
             for spec in specs
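
The promotion rule above can be restated in isolation. The helper below is purely illustrative (should_promote is not part of this diff), and num_memories is assumed to be the number of configured tiers plus one, which is consistent with the new out-of-range regression test (two tiers, mem_id=3 rejected).

from typing import Optional


def should_promote(mem_id: Optional[int], has_constraint: bool, num_memories: int) -> bool:
    """Illustrative restatement of the promotion rule in plan_with_constraints."""
    if mem_id is None or not isinstance(mem_id, int) or mem_id <= 0:
        return False  # None and <= 0 mean "unassigned/unpinned": planner picks freely
    if has_constraint:
        return False  # an explicit AbsolutePlacementConstraint already exists; it wins
    assert 1 <= mem_id < num_memories, f"Pre-set spec.mem_id={mem_id} is invalid."
    return True  # promote: pin the tier, let the planner assign only the offset


# With two configured tiers, valid pre-set mem_ids are 1 and 2 (num_memories == 3 here).
assert should_promote(None, False, 3) is False
assert should_promote(0, False, 3) is False
assert should_promote(-1, False, 3) is False
assert should_promote(2, False, 3) is True
assert should_promote(2, True, 3) is False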
backends/cadence/aot/tests/test_memory_planning_algo.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
import torch.fx
from executorch.backends.cadence.aot.memory_constraints import MemConstraints
from executorch.backends.cadence.aot.memory_planning import (
    PositionBasedGreedyWithHierarchy,
)
from executorch.backends.cadence.aot.memory_planning_algo import MemoryPlanningState
from executorch.backends.cadence.aot.utils import MemoryConfig
from executorch.exir.tensor import TensorSpec


def _make_spec(shape: list[int], *, mem_id: int | None = None) -> TensorSpec:
    """Create a TensorSpec for a uint8 tensor of given shape, optionally pre-pinning mem_id."""
    spec = TensorSpec(dtype=torch.uint8, shape=torch.Size(shape))
    # The planner's overlap checker requires valid lifetimes on every spec.
    spec.lifetime = [0, 1]
    if mem_id is not None:
        spec.mem_id = mem_id
    return spec


def _make_algo_and_state(
    mem_sizes: list[int],
) -> tuple[PositionBasedGreedyWithHierarchy, MemoryPlanningState, MemConstraints]:
    """Build a 2-memory config planner (mem_id 1 = fast, 2 = slow) for tests."""
    config = MemoryConfig(mem_sizes)
    algo = PositionBasedGreedyWithHierarchy(config)
    state = MemoryPlanningState(config)
    constraints = MemConstraints()
    return algo, state, constraints


class TestPinnedMemIdPromotion(unittest.TestCase):
    """Tests for plan_with_constraints pre-set mem_id → AbsolutePlacementConstraint promotion."""

    def _run(
        self,
        specs: list[TensorSpec],
        mem_sizes: list[int],
    ) -> None:
        algo, state, constraints = _make_algo_and_state(mem_sizes)
        gm = torch.fx.GraphModule({}, torch.fx.Graph())
        algo.plan_with_constraints(
            specs, gm, None, state, constraints
        )  # pyre-ignore[6]

    def test_spec_without_preset_mem_id_planned_freely(self) -> None:
        """A spec with no pre-set mem_id is placed by the greedy algo in mem_id=1."""
        spec = _make_spec([512])
        self._run([spec], mem_sizes=[1024, 1024])
        self.assertIsNotNone(spec.mem_id)
        self.assertEqual(spec.mem_id, 1)
        self.assertIsNotNone(spec.mem_offset)

    def test_spec_with_preset_mem_id_stays_in_that_memory(self) -> None:
        """A spec with pre-set mem_id=2 stays in memory 2 even though memory 1 is faster."""
        spec = _make_spec([256])
        spec.mem_id = 2
        self._run([spec], mem_sizes=[4096, 4096])
        # mem_id must be preserved as 2
        self.assertEqual(spec.mem_id, 2)
        # Must have a valid offset assigned
        assert spec.mem_offset is not None
        assert spec.mem_offset >= 0

    def test_preset_mem_id_offset_computed_by_planner(self) -> None:
        """Two specs pinned to mem_id=2 get distinct non-overlapping offsets."""
        spec_a = _make_spec([100])
        spec_b = _make_spec([200])
        spec_a.mem_id = 2
        spec_b.mem_id = 2
        self._run([spec_a, spec_b], mem_sizes=[4096, 4096])
        self.assertEqual(spec_a.mem_id, 2)
        self.assertEqual(spec_b.mem_id, 2)
        # Offsets must not overlap: [a_start, a_end) ∩ [b_start, b_end) == ∅
        a_end = spec_a.mem_offset + spec_a.allocated_memory
        b_end = spec_b.mem_offset + spec_b.allocated_memory
        no_overlap = spec_a.mem_offset >= b_end or spec_b.mem_offset >= a_end
        self.assertTrue(no_overlap, f"Specs overlap: {spec_a} and {spec_b}")

    def test_unpinned_spec_unaffected_by_pinned_peers(self) -> None:
        """Specs without pre-set mem_id are not forced into the pinned tier."""
        pinned = _make_spec([128])
        pinned.mem_id = 2
        free = _make_spec([64])  # No preset; greedy should pick mem_id=1
        self._run([pinned, free], mem_sizes=[4096, 4096])
        self.assertEqual(pinned.mem_id, 2)
        # Greedy algo prefers mem_id=1 (faster) for unconstrained specs
        self.assertEqual(free.mem_id, 1)

    def test_already_constrained_spec_not_overridden(self) -> None:
        """A spec that already has an AbsolutePlacementConstraint is not double-promoted."""
        from executorch.backends.cadence.aot.memory_constraints import (
            AbsolutePlacementConstraint,
        )

        spec = _make_spec([256])
        spec.mem_id = 1  # will be set but constraint added externally to mem_id=2

        algo, state, constraints = _make_algo_and_state([4096, 4096])
        # Add an explicit constraint to mem_id=2 (overrides the spec.mem_id=1 preset)
        constraints.set_absolute_placement_constraint(
            spec, AbsolutePlacementConstraint(pinned_memory_id=2)
        )
        gm = torch.fx.GraphModule({}, torch.fx.Graph())
        algo.plan_with_constraints(
            [spec], gm, None, state, constraints
        )  # pyre-ignore[6]
        # The existing constraint (mem_id=2) takes precedence over spec.mem_id=1
        self.assertEqual(spec.mem_id, 2)

    def test_mem_id_zero_treated_as_unpinned(self) -> None:
        """A spec with mem_id=0 (sentinel for unassigned) should be planned freely."""
        spec = _make_spec([512], mem_id=0)
        self._run([spec], mem_sizes=[1024, 1024])
        # Greedy algo picks mem_id=1 for unconstrained specs
        self.assertEqual(spec.mem_id, 1)
        self.assertIsNotNone(spec.mem_offset)

    def test_mem_id_out_of_range_raises(self) -> None:
        """A spec with mem_id >= num_memories should raise AssertionError."""
        # With 2 memory tiers, valid mem_ids are 1 and 2; mem_id=3 is out of range.
        spec = _make_spec([256], mem_id=3)
        with self.assertRaises(AssertionError):
            self._run([spec], mem_sizes=[4096, 4096])

    def test_mem_id_negative_treated_as_unpinned(self) -> None:
        """A spec with negative mem_id should be treated as unpinned (not promoted)."""
        spec = _make_spec([256])
        spec.mem_id = -1
        self._run([spec], mem_sizes=[1024, 1024])
        # Negative mem_id is filtered out by the >0 check; greedy picks mem_id=1
        self.assertEqual(spec.mem_id, 1)
        self.assertIsNotNone(spec.mem_offset)


if __name__ == "__main__":
    unittest.main()

0 commit comments
