diff --git a/exir/passes/BUCK b/exir/passes/BUCK index a743aecde0b..954f1cfdb4f 100644 --- a/exir/passes/BUCK +++ b/exir/passes/BUCK @@ -451,3 +451,17 @@ fbcode_target(_kind = runtime.python_library, "//caffe2:torch", ], ) + +fbcode_target(_kind = runtime.python_library, + name = "propagate_device_pass", + srcs = [ + "propagate_device_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:delegate", + "//executorch/exir:lowered_backend_module", + "//executorch/exir:schema", + "//executorch/exir:tensor", + ], +) diff --git a/exir/passes/propagate_device_pass.py b/exir/passes/propagate_device_pass.py new file mode 100644 index 00000000000..c36e10c5f56 --- /dev/null +++ b/exir/passes/propagate_device_pass.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging +from typing import Optional + +import executorch.exir.schema as schema + +import torch +from executorch.exir.delegate import executorch_call_delegate +from executorch.exir.lowered_backend_module import LoweredBackendModule +from executorch.exir.tensor import TensorSpec +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +logger: logging.Logger = logging.getLogger(__name__) + +# CompileSpec key convention for specifying the target device. +# Partitioners that target a specific device should include a CompileSpec entry +# with this key and a value encoding the device string (e.g., b"cuda:0"). +TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device" + + +def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]: + """ + Parse a target_device CompileSpec value (e.g., b"cuda:0") into + (DeviceType, device_index). + + The type portion is matched case-insensitively against schema.DeviceType + member names (e.g., "cpu", "cuda"). Raises ValueError for unknown types. 
+ """ + device_str = value.decode("utf-8").strip().lower() + type_str, sep, index_str = device_str.partition(":") + # No ":" means index 0; otherwise the index must be a non-negative integer + # (bare int() accepts "-1" and rejects "" with an opaque message). + if sep and not index_str.strip().isdecimal(): + raise ValueError(f"Invalid device index '{index_str}' in '{device_str}'") + device_index = int(index_str) if sep else 0 + device_type = next( + (dt for dt in schema.DeviceType if dt.name.lower() == type_str), + None, + ) + if device_type is None: + valid = ", ".join(dt.name for dt in schema.DeviceType) + raise ValueError(f"Unknown device type '{type_str}'. Valid types: {valid}") + return device_type, device_index + + +def _get_lowered_module( + graph_module: torch.fx.GraphModule, + delegate_call_node: torch.fx.Node, +) -> Optional[LoweredBackendModule]: + """ + Given an executorch_call_delegate node, retrieve the associated + LoweredBackendModule from the graph module. + The first argument to executorch_call_delegate is a get_attr node + whose target names the LoweredBackendModule attribute. + """ + if len(delegate_call_node.args) < 1: + return None + lowered_node = delegate_call_node.args[0] + if not isinstance(lowered_node, torch.fx.Node) or lowered_node.op != "get_attr": + return None + lowered_module = getattr(graph_module, lowered_node.target, None) + if isinstance(lowered_module, LoweredBackendModule): + return lowered_module + return None + + +def _get_target_device_from_compile_specs( + lowered_module: LoweredBackendModule, +) -> Optional[tuple[schema.DeviceType, int]]: + """ + Look for a CompileSpec with key TARGET_DEVICE_COMPILE_SPEC_KEY and return + the corresponding (DeviceType, device_index), or None if not found. 
+ """ + for spec in lowered_module.compile_specs: + if spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY: + return _parse_device_spec_value(spec.value) + return None + + +def _set_device_on_spec( + spec: TensorSpec, + device_type: schema.DeviceType, + device_index: int = 0, +) -> None: + """Set the device attribute on a TensorSpec.""" + spec.device = device_type + spec.device_index = device_index + + +def _tag_specs_with_device( + specs: object, + device_type: schema.DeviceType, + device_index: int = 0, +) -> bool: + """Apply device annotation to a TensorSpec or a collection of TensorSpecs. + + Args: + specs: A TensorSpec, a tuple/list of TensorSpecs, or None. + device_type: The target device type to set. + device_index: The device index (e.g., 0 for cuda:0, 1 for cuda:1). + + Returns: + True if any spec was modified, False otherwise. + """ + if specs is None: + return False + if isinstance(specs, TensorSpec): + _set_device_on_spec(specs, device_type, device_index) + return True + if isinstance(specs, (tuple, list)): + changed = False + for s in specs: + if isinstance(s, TensorSpec): + _set_device_on_spec(s, device_type, device_index) + changed = True + return changed + return False + + +class PropagateDevicePass(PassBase): + """ + After to_backend, walk the graph and set device metadata on TensorSpecs + based on partitioner-assigned delegation info. + + Rules: + 1. Delegated nodes: Input and output tensors of a delegate call are marked + with the target device derived from the delegate's CompileSpec + (key="target_device"). + 2. Non-delegated nodes: Remain on CPU (default). + 3. Getitem nodes that extract from a delegate call inherit the device from + the delegate call's output spec at the corresponding index. 
+ """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + changed = False + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == executorch_call_delegate: + lowered_module = _get_lowered_module(graph_module, node) + if lowered_module is None: + raise RuntimeError( + f"executorch_call_delegate node '{node.name}' does not reference " + "a valid LoweredBackendModule. The first argument must be a " + "get_attr node pointing to a LoweredBackendModule attribute." + ) + + result = _get_target_device_from_compile_specs(lowered_module) + if result is None: + continue + + target_device_type, device_index = result + + # Tag delegate input tensors. + # args[0] is the get_attr node for the lowered module; skip it. + for arg in node.args[1:]: + if isinstance(arg, torch.fx.Node): + changed |= _tag_specs_with_device( + arg.meta.get("spec"), + target_device_type, + device_index, + ) + + # Tag delegate output tensors. + changed |= _tag_specs_with_device( + node.meta.get("spec"), + target_device_type, + device_index, + ) + + logger.debug( + "PropagateDevicePass: set device=%s on delegate node %s " + "(backend=%s)", + target_device_type, + node.name, + lowered_module.backend_id, + ) + + # Second pass: propagate device through getitem nodes that extract + # individual outputs from a delegate call. 
+ for node in graph_module.graph.nodes: + # Match getitem by name, tolerating callables that lack __name__. + if node.op == "call_function" and getattr(node.target, "__name__", None) == "getitem": + source_node = node.args[0] + if ( + isinstance(source_node, torch.fx.Node) + and source_node.op == "call_function" + and source_node.target == executorch_call_delegate + ): + spec = node.meta.get("spec") + source_specs = source_node.meta.get("spec") + idx = node.args[1] + if ( + isinstance(spec, TensorSpec) + and source_specs is not None + and isinstance(source_specs, (tuple, list)) + and isinstance(idx, int) + and idx < len(source_specs) + ): + source_spec = source_specs[idx] + if isinstance(source_spec, TensorSpec): + _set_device_on_spec( + spec, + source_spec.device, + source_spec.device_index, + ) + changed = True + + return PassResult(graph_module, changed) diff --git a/exir/passes/replace_view_copy_with_view_pass.py b/exir/passes/replace_view_copy_with_view_pass.py index b19cfbed95d..28fcc97aaf5 100644 --- a/exir/passes/replace_view_copy_with_view_pass.py +++ b/exir/passes/replace_view_copy_with_view_pass.py @@ -110,6 +110,8 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None: "mem_offset", "dtype", # property "extra_tensor_info", # property + "device", + "device_index", ] # Make sure _self_fields and _base_fields are disjoint diff --git a/exir/program/BUCK b/exir/program/BUCK index 221e27c3087..7d9642efdb7 100644 --- a/exir/program/BUCK +++ b/exir/program/BUCK @@ -40,6 +40,7 @@ fbcode_target(_kind = runtime.python_library, "//executorch/exir/passes:insert_write_back_for_buffers_pass", "//executorch/exir/passes:lib", "//executorch/exir/passes:normalize_view_copy_base_pass", + "//executorch/exir/passes:propagate_device_pass", "//executorch/exir/passes:remove_graph_asserts_pass", "//executorch/exir/passes:remove_mixed_type_operators", "//executorch/exir/passes:replace_aten_with_edge_pass", diff --git a/exir/program/_program.py b/exir/program/_program.py index baacd5eaec4..c68d0eed945 100644 --- 
a/exir/program/_program.py +++ b/exir/program/_program.py @@ -59,6 +59,7 @@ from executorch.exir.passes.normalize_view_copy_base_pass import ( NormalizeViewCopyBasePass, ) +from executorch.exir.passes.propagate_device_pass import PropagateDevicePass from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass from executorch.exir.passes.reinplace import reinplace_pass from executorch.exir.passes.remove_graph_asserts_pass import ( @@ -848,6 +849,7 @@ def edge_to_executorch_passes( # there exists an unbacked symint operation. *config.passes, SpecPropPass(), + PropagateDevicePass(), EdgeToBackendOpsPass(), RemoveGraphAssertsPass(), ] + pre_memory_planning_passes(config, name) diff --git a/exir/tensor.py b/exir/tensor.py index b80a637ea96..79f8fff4abc 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -172,6 +172,9 @@ def __init__( self.init_mem_planning_fields() self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape) self.extra_tensor_info = extra_tensor_info + # Defaults to CPU, index 0; only PropagateDevicePass updates these device fields. 
+ self.device: schema.DeviceType = schema.DeviceType.CPU + self.device_index: int = 0 @property def allocated_memory(self) -> int: @@ -254,6 +257,7 @@ def __repr__(self) -> str: + f", is_sparse={self.is_sparse}" + f", shape_dynamism={self.shape_dynamism}" + f", const={self.const}, requires_grad={self.requires_grad}" + + f", device={self.device.name}:{self.device_index}" + ")" ) diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index c9136ce51da..322f72c870a 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -484,3 +484,23 @@ python_unittest( "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", ], ) + +python_unittest( + name = "propagate_device_pass", + srcs = [ + "test_propagate_device_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:schema", + "//executorch/exir:tensor", + "//executorch/exir/backend:backend_api", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/backend:partitioner", + "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib", + "//executorch/exir/backend/test:backend_with_compiler_demo", + "//executorch/exir/dialects:lib", + "//executorch/exir/passes:propagate_device_pass", + ], +) diff --git a/exir/tests/test_propagate_device_pass.py b/exir/tests/test_propagate_device_pass.py new file mode 100644 index 00000000000..26249991be9 --- /dev/null +++ b/exir/tests/test_propagate_device_pass.py @@ -0,0 +1,438 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import operator +import unittest +from copy import deepcopy +from typing import Dict, final, List + +import torch +from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower +from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_pattern_op_partitions, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.test.backend_with_compiler_demo import ( + BackendWithCompilerDemo, +) +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.delegate import executorch_call_delegate +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes.propagate_device_pass import ( + _get_target_device_from_compile_specs, + _parse_device_spec_value, + TARGET_DEVICE_COMPILE_SPEC_KEY, +) +from executorch.exir.schema import DeviceType +from executorch.exir.tensor import TensorSpec +from torch.export import export +from torch.fx.passes.operator_support import any_chain, OperatorSupportBase + + +class AddOperatorSupport(OperatorSupportBase): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + exir_ops.edge.aten.add.Tensor, + ] + + +@final +class DeviceAwarePartitioner(Partitioner): + def __init__(self, target_device: str = "cuda:0") -> None: + super().__init__() + self.op_support = any_chain(AddOperatorSupport()) + self.delegation_spec = DelegationSpec( + BackendWithCompilerDemo.__name__, + [ + CompileSpec("max_value", bytes([4])), + CompileSpec( + TARGET_DEVICE_COMPILE_SPEC_KEY, + target_device.encode("utf-8"), + ), + ], + ) + + def partition(self, exported_program) -> PartitionResult: + partition_tags: Dict[str, DelegationSpec] = {} + partition_list = generate_pattern_op_partitions( + exported_program.graph_module, 
op_support=self.op_support + ) + for partition in partition_list: + for node in partition.nodes: + delegation_tag = f"tag{partition.id}" + node.meta["delegation_tag"] = delegation_tag + partition_tags[delegation_tag] = self.delegation_spec + return PartitionResult( + tagged_exported_program=exported_program, + partition_tags=partition_tags, + ) + + +@final +class CpuOnlyPartitioner(Partitioner): + def __init__(self) -> None: + super().__init__() + self.op_support = any_chain(AddOperatorSupport()) + self.delegation_spec = DelegationSpec( + BackendWithCompilerDemo.__name__, + [CompileSpec("max_value", bytes([4]))], + ) + + def partition(self, exported_program) -> PartitionResult: + partition_tags: Dict[str, DelegationSpec] = {} + partition_list = generate_pattern_op_partitions( + exported_program.graph_module, op_support=self.op_support + ) + for partition in partition_list: + for node in partition.nodes: + delegation_tag = f"tag{partition.id}" + node.meta["delegation_tag"] = delegation_tag + partition_tags[delegation_tag] = self.delegation_spec + return PartitionResult( + tagged_exported_program=exported_program, + partition_tags=partition_tags, + ) + + +def _lower_model_to_executorch( + model: torch.nn.Module, + inputs: tuple, + partitioner: Partitioner, +) -> List: + """Lower model all the way through to_executorch for E2E tests.""" + ep = export(model, inputs) + ep_copied = deepcopy(ep) + + edge_1 = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) + lowered_1 = edge_1.to_backend(partitioner) + et_1 = lowered_1.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False)) + gm_1 = et_1.exported_program().graph_module + + edge_2 = to_edge_transform_and_lower(ep_copied, partitioner=[partitioner]) + et_2 = edge_2.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False)) + gm_2 = et_2.exported_program().graph_module + + return [ + ("to_edge+to_backend", gm_1), + ("to_edge_transform_and_lower", gm_2), + ] + + +class 
TestPropagateDevicePass(unittest.TestCase): + @staticmethod + def _collect_tensor_specs(node: torch.fx.Node) -> List[TensorSpec]: + """Return a flat list of TensorSpecs from a node's 'spec' metadata.""" + spec = node.meta.get("spec") + if spec is None: + return [] + if isinstance(spec, TensorSpec): + return [spec] + if isinstance(spec, (tuple, list)): + return [s for s in spec if isinstance(s, TensorSpec)] + return [] + + @staticmethod + def _is_delegate_getitem(node: torch.fx.Node) -> bool: + """Return True if *node* is a getitem extracting from a delegate call.""" + if node.target != operator.getitem: + return False + source = node.args[0] + return ( + isinstance(source, torch.fx.Node) + and source.op == "call_function" + and source.target == executorch_call_delegate + ) + + def _assert_specs_device( + self, + specs: List[TensorSpec], + expected_device: DeviceType, + msg: str, + expected_index: int | None = None, + ) -> None: + """Assert every spec has the expected device (and optionally index).""" + for s in specs: + self.assertEqual(s.device, expected_device, msg) + if expected_index is not None: + self.assertEqual(s.device_index, expected_index) + + def test_device_consistency_cuda_1(self): + """Verify device tags are correct with cuda:1 after to_executorch() + to verify device_index propagation through the full pipeline.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + + for pipeline, gm in _lower_model_to_executorch( + model, inputs, DeviceAwarePartitioner("cuda:1") + ): + with self.subTest(pipeline=pipeline): + for node in gm.graph.nodes: + if node.op != "call_function": + continue + specs = self._collect_tensor_specs(node) + if not specs: + continue + + label = f"[{pipeline}] '{node.name}'" + if node.target == executorch_call_delegate: + self._assert_specs_device( + specs, + DeviceType.CUDA, + f"{label} Delegate should be CUDA", + 
expected_index=1, + ) + elif self._is_delegate_getitem(node): + self._assert_specs_device( + specs, + DeviceType.CUDA, + f"{label} Delegate getitem should be CUDA", + expected_index=1, + ) + + def test_no_device_spec_remains_cpu(self): + """When partitioner has no target_device, all specs remain CPU + through the full to_executorch pipeline.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + + for pipeline, gm in _lower_model_to_executorch( + model, inputs, CpuOnlyPartitioner() + ): + with self.subTest(pipeline=pipeline): + for node in gm.graph.nodes: + specs = self._collect_tensor_specs(node) + for s in specs: + self.assertEqual( + s.device, + DeviceType.CPU, + f"[{pipeline}] All specs should be CPU when no " + f"target_device, but node '{node.name}' is {s.device.name}", + ) + + def test_device_consistency_after_to_executorch(self): + """Verify device tags are correct in the final graph after + to_executorch(), not just after PropagateDevicePass alone. 
+ Copy nodes should bridge CPU ↔ device at delegate boundaries.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + + for pipeline, gm in _lower_model_to_executorch( + model, inputs, DeviceAwarePartitioner("cuda:0") + ): + with self.subTest(pipeline=pipeline): + for node in gm.graph.nodes: + if node.op != "call_function": + continue + specs = self._collect_tensor_specs(node) + if not specs: + continue + + label = f"[{pipeline}] '{node.name}'" + if node.target == executorch_call_delegate: + self._assert_specs_device( + specs, + DeviceType.CUDA, + f"{label} Delegate should be CUDA", + expected_index=0, + ) + elif self._is_delegate_getitem(node): + self._assert_specs_device( + specs, + DeviceType.CUDA, + f"{label} Delegate getitem should be CUDA", + expected_index=0, + ) + + # --- Unit tests for helper functions --- + + def test_parse_device_spec_value(self): + dt, idx = _parse_device_spec_value(b"cuda:0") + self.assertEqual(dt, DeviceType.CUDA) + self.assertEqual(idx, 0) + + dt, idx = _parse_device_spec_value(b"cuda:1") + self.assertEqual(dt, DeviceType.CUDA) + self.assertEqual(idx, 1) + + dt, idx = _parse_device_spec_value(b"cpu") + self.assertEqual(dt, DeviceType.CPU) + self.assertEqual(idx, 0) + + def test_parse_device_spec_value_unknown_raises(self): + with self.assertRaises(ValueError): + _parse_device_spec_value(b"tpu:0") + + def test_parse_device_spec_value_case_insensitive(self): + dt, idx = _parse_device_spec_value(b"CUDA:0") + self.assertEqual(dt, DeviceType.CUDA) + self.assertEqual(idx, 0) + + dt, idx = _parse_device_spec_value(b"Cuda:2") + self.assertEqual(dt, DeviceType.CUDA) + self.assertEqual(idx, 2) + + def test_get_target_device_from_compile_specs(self): + class MockLoweredModule: + __slots__ = ["compile_specs"] + + def __init__(self, specs): + self.compile_specs = specs + + module_with_cuda = MockLoweredModule( + [ + CompileSpec("max_value", 
bytes([4])), + CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"), + ] + ) + result = _get_target_device_from_compile_specs(module_with_cuda) + self.assertIsNotNone(result) + dt, idx = result + self.assertEqual(dt, DeviceType.CUDA) + self.assertEqual(idx, 0) + + module_without_device = MockLoweredModule( + [ + CompileSpec("max_value", bytes([4])), + ] + ) + result = _get_target_device_from_compile_specs(module_without_device) + self.assertIsNone(result) + + # ---- End-to-end tests: verify device info survives to_executorch ---- + + def _get_executorch_program(self, model, inputs, partitioner): + """Run the full pipeline and return (emitted_program, graph_module) pairs + for both export pipelines.""" + from executorch.exir.capture._config import ExecutorchBackendConfig + + ep = export(model, inputs) + ep_copied = deepcopy(ep) + + # Pipeline 1: to_edge → to_backend → to_executorch + edge_1 = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) + lowered_1 = edge_1.to_backend(partitioner) + et_1 = lowered_1.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False)) + program_1 = et_1._emitter_output.program + gm_1 = et_1.exported_program().graph_module + + # Pipeline 2: to_edge_transform_and_lower → to_executorch + edge_2 = to_edge_transform_and_lower(ep_copied, partitioner=[partitioner]) + et_2 = edge_2.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False)) + program_2 = et_2._emitter_output.program + gm_2 = et_2.exported_program().graph_module + + return [ + ("to_edge+to_backend", program_1, gm_1), + ("to_edge_transform_and_lower", program_2, gm_2), + ] + + def test_e2e_device_on_specs_after_to_executorch(self): + """ + After the full to_executorch pipeline, delegate output TensorSpecs + should still have device == CUDA on the graph module nodes. 
+ """ + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + + for pipeline, _program, gm in self._get_executorch_program( + model, inputs, DeviceAwarePartitioner("cuda:0") + ): + with self.subTest(pipeline=pipeline): + found_delegate = False + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target == executorch_call_delegate + ): + found_delegate = True + specs = node.meta.get("spec") + self.assertIsNotNone(specs) + if isinstance(specs, TensorSpec): + self.assertEqual( + specs.device, + DeviceType.CUDA, + f"[{pipeline}] spec.device should be CUDA after to_executorch", + ) + elif isinstance(specs, (tuple, list)): + for s in specs: + if isinstance(s, TensorSpec): + self.assertEqual( + s.device, + DeviceType.CUDA, + f"[{pipeline}] spec.device should be CUDA after to_executorch", + ) + + self.assertTrue(found_delegate) + + def test_e2e_non_delegated_tensor_specs_remain_cpu(self): + """ + After to_executorch, non-delegated node specs should still be CPU. + Getitem nodes extracting from a delegate call are considered delegated. 
+ """ + + class Model(torch.nn.Module): + def forward(self, a, b): + c = torch.add(a, b) + d = torch.sin(c) + return d + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + + for pipeline, _program, gm in self._get_executorch_program( + model, inputs, DeviceAwarePartitioner("cuda:0") + ): + with self.subTest(pipeline=pipeline): + for node in gm.graph.nodes: + if node.op != "call_function": + continue + # Skip delegate call nodes + if node.target == executorch_call_delegate: + continue + # Skip getitem nodes that extract from a delegate call + if self._is_delegate_getitem(node): + continue + # Remaining call_function nodes are not delegated: expect CPU specs. + self._assert_specs_device( + self._collect_tensor_specs(node), + DeviceType.CPU, + f"[{pipeline}] non-delegated node '{node.name}' should be CPU", + ) + + def test_tensorspec_repr_includes_device(self): + spec = TensorSpec(dtype=torch.float32, shape=torch.Size([2, 3])) + repr_str = repr(spec) + self.assertIn("device=", repr_str) + self.assertIn("CPU", repr_str) + + +if __name__ == "__main__": + unittest.main()