Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions exir/passes/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,17 @@ fbcode_target(_kind = runtime.python_library,
"//caffe2:torch",
],
)

fbcode_target(_kind = runtime.python_library,
name = "propagate_device_pass",
srcs = [
"propagate_device_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:delegate",
"//executorch/exir:lowered_backend_module",
"//executorch/exir:schema",
"//executorch/exir:tensor",
],
)
210 changes: 210 additions & 0 deletions exir/passes/propagate_device_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Optional

import executorch.exir.schema as schema

import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.tensor import TensorSpec
from torch.fx.passes.infra.pass_base import PassBase, PassResult

logger: logging.Logger = logging.getLogger(__name__)

# CompileSpec key convention for specifying the target device.
# Partitioners that target a specific device should include a CompileSpec entry
# with this key and a value encoding the device string (e.g., b"cuda:0").
TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"


def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]:
"""
Parse a target_device CompileSpec value (e.g., b"cuda:0") into
(DeviceType, device_index).

The type portion is matched case-insensitively against schema.DeviceType
member names (e.g., "cpu", "cuda"). Raises ValueError for unknown types.
"""
device_str = value.decode("utf-8").strip().lower()
if ":" in device_str:
type_str, index_str = device_str.split(":", 1)
device_index = int(index_str)
else:
type_str = device_str
device_index = 0
device_type = next(
(dt for dt in schema.DeviceType if dt.name.lower() == type_str),
None,
)
if device_type is None:
valid = ", ".join(dt.name for dt in schema.DeviceType)
raise ValueError(f"Unknown device type '{type_str}'. Valid types: {valid}")
return device_type, device_index


def _get_lowered_module(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this type of util really should be placed in a single spot. There are other things like this in the passes. Lets take it as a follow up to have claude just search for generic utils like this and centralize them

graph_module: torch.fx.GraphModule,
delegate_call_node: torch.fx.Node,
) -> Optional[LoweredBackendModule]:
"""
Given an executorch_call_delegate node, retrieve the associated
LoweredBackendModule from the graph module.
The first argument to executorch_call_delegate is a get_attr node
whose target names the LoweredBackendModule attribute.
"""
if len(delegate_call_node.args) < 1:
return None
lowered_node = delegate_call_node.args[0]
if not isinstance(lowered_node, torch.fx.Node) or lowered_node.op != "get_attr":
return None
lowered_module = getattr(graph_module, lowered_node.target, None)
if isinstance(lowered_module, LoweredBackendModule):
return lowered_module
return None


def _get_target_device_from_compile_specs(
lowered_module: LoweredBackendModule,
) -> Optional[tuple[schema.DeviceType, int]]:
"""
Look for a CompileSpec with key TARGET_DEVICE_COMPILE_SPEC_KEY and return
the corresponding (DeviceType, device_index), or None if not found.
"""
for spec in lowered_module.compile_specs:
if spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY:
return _parse_device_spec_value(spec.value)
return None


def _set_device_on_spec(
spec: TensorSpec,
device_type: schema.DeviceType,
device_index: int = 0,
) -> None:
"""Set the device attribute on a TensorSpec."""
spec.device = device_type
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these fields already in the TensorSpec class definition? Are they initialized to just cpu and 0?

spec.device_index = device_index


def _tag_specs_with_device(
specs: object,
device_type: schema.DeviceType,
device_index: int = 0,
) -> bool:
"""Apply device annotation to a TensorSpec or a collection of TensorSpecs.

Args:
specs: A TensorSpec, a tuple/list of TensorSpecs, or None.
device_type: The target device type to set.
device_index: The device index (e.g., 0 for cuda:0, 1 for cuda:1).

Returns:
True if any spec was modified, False otherwise.
"""
if specs is None:
return False
if isinstance(specs, TensorSpec):
_set_device_on_spec(specs, device_type, device_index)
return True
if isinstance(specs, (tuple, list)):
changed = False
for s in specs:
if isinstance(s, TensorSpec):
_set_device_on_spec(s, device_type, device_index)
changed = True
return changed
return False


class PropagateDevicePass(PassBase):
"""
After to_backend, walk the graph and set device metadata on TensorSpecs
based on partitioner-assigned delegation info.

Rules:
1. Delegated nodes: Input and output tensors of a delegate call are marked
with the target device derived from the delegate's CompileSpec
(key="target_device").
2. Non-delegated nodes: Remain on CPU (default).
3. Getitem nodes that extract from a delegate call inherit the device from
the delegate call's output spec at the corresponding index.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
changed = False
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == executorch_call_delegate:
lowered_module = _get_lowered_module(graph_module, node)
if lowered_module is None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should throw here no?

continue

result = _get_target_device_from_compile_specs(lowered_module)
Copy link
Copy Markdown
Contributor

@digantdesai digantdesai Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This effectively assumes that we know the device 'name' AoT. In theory, we can have a multi-device delegate then the runtime might interpret this name differently and that can cause some confusion i.e cuda:0 device on Metal.

I am not sure about using generic names like 'gpu' but also not sure about following PyTorch's eager/jit style naming convention where you won't switch devices underneath.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I have your suggestions on the executorch device name?

Currently we set up the device name AOT and intentionally decouple dour device attribute with pytorch/pytorch device concept; we created a enum in the etensor schema for all devices we are supporting right now. In this way we can support as much as device as we want.

For the situaton you mentioned, if other backend like vulken need its own gpu device, they should add a new one to the enum. We should avoid using generic names like 'gpu'.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multi-device graph serialization will necessitate multiple graphs. We can maybe make an exception for input tensors, but for any intermediate the runtime needs to know what the device its loading intermediates onto.

Device is fixed at export aot. If you want to have some generic shader style lib where the gpu type is decided lazily then you will have to use a generic key like gpu.

if result is None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it not return cpu by default

continue

target_device_type, device_index = result

# Tag delegate input tensors.
# args[0] is the get_attr node for the lowered module; skip it.
for arg in node.args[1:]:
if isinstance(arg, torch.fx.Node):
changed |= _tag_specs_with_device(
arg.meta.get("spec"),
target_device_type,
device_index,
)

# Tag delegate output tensors.
changed |= _tag_specs_with_device(
node.meta.get("spec"),
target_device_type,
device_index,
)

logger.debug(
"PropagateDevicePass: set device=%s on delegate node %s "
"(backend=%s)",
target_device_type,
node.name,
lowered_module.backend_id,
)

# Second pass: propagate device through getitem nodes that extract
# individual outputs from a delegate call.
for node in graph_module.graph.nodes:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just do 1 pass. You can look at users of the delegate node to find the getitem nodes.

if node.op == "call_function" and node.target.__name__ == "getitem":
source_node = node.args[0]
if (
isinstance(source_node, torch.fx.Node)
and source_node.op == "call_function"
and source_node.target == executorch_call_delegate
):
spec = node.meta.get("spec")
source_specs = source_node.meta.get("spec")
idx = node.args[1]
if (
spec is not None
and isinstance(spec, TensorSpec)
and source_specs is not None
and isinstance(source_specs, (tuple, list))
and isinstance(idx, int)
and idx < len(source_specs)
):
source_spec = source_specs[idx]
if isinstance(source_spec, TensorSpec):
_set_device_on_spec(
spec,
source_spec.device,
source_spec.device_index,
)
changed = True

return PassResult(graph_module, changed)
2 changes: 2 additions & 0 deletions exir/passes/replace_view_copy_with_view_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
"mem_offset",
"dtype", # property
"extra_tensor_info", # property
"device",
"device_index",
]

# Make sure _self_fields and _base_fields are disjoint
Expand Down
1 change: 1 addition & 0 deletions exir/program/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ fbcode_target(_kind = runtime.python_library,
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
"//executorch/exir/passes:lib",
"//executorch/exir/passes:normalize_view_copy_base_pass",
"//executorch/exir/passes:propagate_device_pass",
"//executorch/exir/passes:remove_graph_asserts_pass",
"//executorch/exir/passes:remove_mixed_type_operators",
"//executorch/exir/passes:replace_aten_with_edge_pass",
Expand Down
2 changes: 2 additions & 0 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from executorch.exir.passes.normalize_view_copy_base_pass import (
NormalizeViewCopyBasePass,
)
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
from executorch.exir.passes.reinplace import reinplace_pass
from executorch.exir.passes.remove_graph_asserts_pass import (
Expand Down Expand Up @@ -848,6 +849,7 @@ def edge_to_executorch_passes(
# there exists an unbacked symint operation.
*config.passes,
SpecPropPass(),
PropagateDevicePass(),
EdgeToBackendOpsPass(),
RemoveGraphAssertsPass(),
] + pre_memory_planning_passes(config, name)
Expand Down
4 changes: 4 additions & 0 deletions exir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ def __init__(
self.init_mem_planning_fields()
self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape)
self.extra_tensor_info = extra_tensor_info
# device type will be only updated during PropagateDevicePass.
self.device: schema.DeviceType = schema.DeviceType.CPU
self.device_index: int = 0

@property
def allocated_memory(self) -> int:
Expand Down Expand Up @@ -254,6 +257,7 @@ def __repr__(self) -> str:
+ f", is_sparse={self.is_sparse}"
+ f", shape_dynamism={self.shape_dynamism}"
+ f", const={self.const}, requires_grad={self.requires_grad}"
+ f", device={self.device.name}:{self.device_index}"
+ ")"
)

Expand Down
20 changes: 20 additions & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,23 @@ python_unittest(
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
],
)

python_unittest(
name = "propagate_device_pass",
srcs = [
"test_propagate_device_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:schema",
"//executorch/exir:tensor",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
"//executorch/exir/backend/test:backend_with_compiler_demo",
"//executorch/exir/dialects:lib",
"//executorch/exir/passes:propagate_device_pass",
],
)
Loading
Loading