@main.command()
@click.argument("model", type=click.Path(exists=True, path_type=Path))
@click.option(
    "--facts",
    multiple=True,
    help="Set facts as key=value pairs (e.g. battery.voltage_mv=3300).",
)
@click.option(
    "--timestamps",
    multiple=True,
    help="Set fact timestamps as key=ms pairs (e.g. battery.voltage_mv=100).",
)
@click.option(
    "--snapshot-ts",
    type=int,
    default=0,
    help="Snapshot timestamp in ms (for staleness checks).",
)
@click.option("--json", "emit_json", is_flag=True, help="Output result as JSON.")
def eval(  # noqa: A001 - click derives the command name from the function name
    model: Path,
    facts: tuple[str, ...],
    timestamps: tuple[str, ...],
    snapshot_ts: int,
    emit_json: bool,
) -> None:
    """Evaluate a .arb.yaml model with given facts."""
    # NOTE: the docstring above doubles as the click help text, so the
    # details live in comments instead.
    #
    # Flow: parse the model, build an ArbiterEvaluator, apply every
    # ``--facts`` and ``--timestamps`` pair, run one eval pass and print the
    # result (JSON when ``--json`` is given, a short text summary otherwise).
    # Any malformed pair, unknown fact or non-numeric value prints
    # ``Error: ...`` on stderr and exits with status 1.
    import json as json_mod

    def fail(message: object) -> None:
        """Report *message* on stderr and terminate with exit status 1."""
        click.echo(f"Error: {message}", err=True)
        sys.exit(1)

    def describe(exc: Exception) -> object:
        """Return the bare message of *exc* (str(KeyError) adds quotes)."""
        return exc.args[0] if exc.args else exc

    diag = DiagnosticCollector()
    data = parse_model(model, diag)
    if data is None:
        click.echo(diag.format(), err=True)
        sys.exit(1)

    evaluator = ArbiterEvaluator(data)

    for pair in facts:
        if "=" not in pair:
            fail(f"invalid fact '{pair}', expected key=value")
        key, raw_value = pair.split("=", 1)
        try:
            evaluator.set_fact(key, raw_value)
        except KeyError as exc:
            fail(describe(exc))
        except ValueError:
            # set_fact() falls back to int(); a non-numeric value used to
            # escape as an uncaught traceback.
            fail(
                f"invalid value '{raw_value}' for fact '{key}', "
                "expected an integer, true or false"
            )

    for pair in timestamps:
        if "=" not in pair:
            fail(f"invalid timestamp '{pair}', expected key=ms")
        key, raw_ms = pair.split("=", 1)
        try:
            evaluator.set_timestamp(key, int(raw_ms))
        except (KeyError, ValueError) as exc:
            fail(describe(exc))

    evaluator.set_snapshot_timestamp(snapshot_ts)
    result = evaluator.eval()

    if emit_json:
        click.echo(json_mod.dumps(result.to_dict(), indent=2))
        return

    click.echo(f"Fired rules: {result.fired_rules}")
    if result.current_mode:
        click.echo(f"Mode: {result.current_mode}")
    if result.requested_actions:
        click.echo(f"Actions: {result.requested_actions}")
    if result.raised_faults:
        click.echo(f"Faults: {sorted(result.raised_faults)}")
    click.echo(f"Op count: {result.op_count}")
b/python/arbiter/evaluator.py @@ -0,0 +1,505 @@ +# SPDX-License-Identifier: MIT +"""Pure-Python evaluator that mirrors the C arbiter engine exactly. + +This module provides a Python implementation of the deterministic eval loop +so that models can be exercised and tested without compiling C or flashing +a Zephyr target. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from .canonical import CanonicalModel, canonicalize + +# --------------------------------------------------------------------------- +# Constants — mirror the C engine +# --------------------------------------------------------------------------- + +INT32_MAX = 2_147_483_647 +INT32_MIN = -2_147_483_648 +UINT16_MAX = 65535 # sentinel for "use literal" + +# Rule class evaluation order — safety_guard runs first. +_RULE_CLASS_ORDER = { + "safety_guard": 0, + "obligation": 1, + "constraint": 2, + "mode_guard": 3, + "inference": 4, + "advisory": 5, +} + +# Action types that map to fault operations. 
_ACTION_TYPE_MAP = {
    "callback": "callback",
    "log": "log",
    "notify": "notify",
    "set_fact": "set_fact",
    "set_mode": "set_mode",
    "raise_fault": "raise_fault",
    "clear_fault": "clear_fault",
}


# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------


@dataclass
class TraceEntry:
    """One rule evaluation trace record."""

    rule_id: str
    fired: bool
    # The rule's ``then.explanation`` text; left empty when the rule did not fire.
    reason: str = ""


@dataclass
class EvalResult:
    """Output of a single evaluation pass."""

    # Ids of the rules that fired, in evaluation order.
    fired_rules: list[str] = field(default_factory=list)
    # Target of the last ``set_mode`` that fired in this pass, if any.
    current_mode: str | None = None
    # Snapshot of the evaluator's persistent fault set after the pass.
    raised_faults: set[str] = field(default_factory=set)
    requested_actions: list[str] = field(default_factory=list)
    # One per rule condition block evaluated plus one per expression executed.
    op_count: int = 0
    trace: list[TraceEntry] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a plain dict (for JSON output).

        ``raised_faults`` is emitted as a sorted list so the output is
        deterministic and JSON-serialisable.
        """
        return {
            "fired_rules": self.fired_rules,
            "current_mode": self.current_mode,
            "raised_faults": sorted(self.raised_faults),
            "requested_actions": self.requested_actions,
            "op_count": self.op_count,
            "trace": [
                {"rule_id": t.rule_id, "fired": t.fired, "reason": t.reason}
                for t in self.trace
            ],
        }


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _saturate32(value: int) -> int:
    """Clamp an arbitrary-precision int to the int32 range."""
    if value > INT32_MAX:
        return INT32_MAX
    if value < INT32_MIN:
        return INT32_MIN
    return value


# ---------------------------------------------------------------------------
# Evaluator
# ---------------------------------------------------------------------------


class ArbiterEvaluator:
    """Pure-Python evaluator for a canonicalised ARB model.

    Usage::

        ev = ArbiterEvaluator(model_data)
        ev.set_fact("battery.voltage_mv", 3400)
        result = ev.eval()

    Fact values, previous values, timestamps and raised faults persist on the
    instance between :meth:`eval` calls.
    """

    def __init__(self, model_data: dict[str, Any]) -> None:
        """Canonicalise *model_data* and initialise every fact to its default."""
        self._raw = model_data
        self._model: CanonicalModel = canonicalize(model_data)

        # Fact state, keyed by fact id.
        self._fact_values: dict[str, int] = {}
        self._fact_prev: dict[str, int] = {}
        self._fact_timestamps: dict[str, int] = {}
        self._fact_valid: dict[str, bool] = {}

        for f in self._model.facts:
            name = f["id"]
            raw_default = f.get("default")
            # Decide on the bool type first so a bool fact never goes through
            # int(); a missing or None default means 0.
            if f.get("type") == "bool":
                default = int(bool(raw_default))
            elif raw_default is None:
                default = 0
            else:
                default = int(raw_default)
            self._fact_values[name] = default
            self._fact_prev[name] = default
            self._fact_timestamps[name] = 0
            self._fact_valid[name] = False

        # Snapshot timestamp (set by caller for staleness tests).
        self._snapshot_ts: int = 0

        # Faults (persistent across evals until cleared).
        self._raised_faults: set[str] = set()

    # ---- public API -------------------------------------------------------

    def set_fact(self, name: str, value: Any) -> None:
        """Set a fact value and mark the fact as valid.

        Bools become 0/1, the strings ``"true"``/``"false"`` (any case)
        become 1/0, and everything else goes through ``int()``.

        Raises:
            KeyError: *name* is not a fact of the model.
            ValueError: *value* cannot be converted to an int.
        """
        if name not in self._fact_values:
            raise KeyError(f"Unknown fact: {name}")
        if isinstance(value, bool):
            converted = int(value)
        elif isinstance(value, str):
            # CLI may pass "true" / "false"
            lowered = value.lower()
            if lowered == "true":
                converted = 1
            elif lowered == "false":
                converted = 0
            else:
                converted = int(value)
        else:
            converted = int(value)
        self._fact_values[name] = converted
        self._fact_valid[name] = True

    def set_timestamp(self, name: str, ms: int) -> None:
        """Set the timestamp for a fact (for staleness detection).

        Also marks the fact as valid.

        Raises:
            KeyError: *name* is not a fact of the model.
        """
        if name not in self._fact_values:
            raise KeyError(f"Unknown fact: {name}")
        self._fact_timestamps[name] = int(ms)
        self._fact_valid[name] = True

    def set_snapshot_timestamp(self, ms: int) -> None:
        """Set the global snapshot (eval) timestamp."""
        self._snapshot_ts = int(ms)

    def eval(self) -> EvalResult:
        """Run one evaluation pass. Returns an :class:`EvalResult`.

        Rules are evaluated in class-priority order, then by id. For each
        rule that fires, its mode change, compute expressions, action and
        inline fault operations are applied in that order. When the pass
        ends, the current fact values become the "previous" values seen by
        the next pass.
        """
        result = EvalResult()
        op_count = 0

        # Snapshot prev values for "changed" / "delta" operators.
        prev_snapshot: dict[str, int] = dict(self._fact_prev)

        # Sort rules: safety_guard first, then by canonical order (alpha by id).
        ordered_rules = sorted(
            self._model.rules,
            key=lambda r: (
                _RULE_CLASS_ORDER.get(r.get("class", "inference"), 4),
                r.get("id", ""),
            ),
        )

        for rule in ordered_rules:
            rule_id = rule["id"]
            when = rule.get("when", {})
            then = rule.get("then", {})
            if not isinstance(then, dict):
                then = {}

            fired = self._eval_condition_block(when, prev_snapshot)
            op_count += 1  # condition evaluation counts as 1 op

            reason = then.get("explanation", "")

            result.trace.append(TraceEntry(
                rule_id=rule_id,
                fired=fired,
                reason=reason if fired else "",
            ))

            if not fired:
                continue

            result.fired_rules.append(rule_id)

            # --- Mode transition --- (the last rule to set a mode wins)
            mode_target = then.get("set_mode")
            if mode_target:
                result.current_mode = mode_target

            # --- Compute expressions ---
            # The rule owns the slice [_expr_start, _expr_start + _expr_count)
            # of the model's flat expression table.
            expr_start = rule.get("_expr_start", 0)
            expr_count = rule.get("_expr_count", 0)
            for i in range(expr_start, expr_start + expr_count):
                if i < len(self._model.expressions):
                    self._exec_expression(self._model.expressions[i])
                    op_count += 1

            # --- Action ---
            action_ref = then.get("action")
            if action_ref:
                result.requested_actions.append(action_ref)
                # A fault-typed action raises/clears a fault named after the
                # action id itself.
                action_def = self._find_action(action_ref)
                if action_def:
                    atype = action_def.get("type", "callback")
                    if atype == "raise_fault":
                        self._raised_faults.add(action_ref)
                    elif atype == "clear_fault":
                        self._raised_faults.discard(action_ref)

            # --- Inline raise_fault / clear_fault in then block ---
            if then.get("raise_fault"):
                fault_id = then["raise_fault"]
                self._raised_faults.add(fault_id)
                result.requested_actions.append(f"raise_fault:{fault_id}")
            if then.get("clear_fault"):
                fault_id = then["clear_fault"]
                self._raised_faults.discard(fault_id)
                result.requested_actions.append(f"clear_fault:{fault_id}")

        # Save prev values for next eval cycle.
        self._fact_prev = dict(self._fact_values)

        result.raised_faults = set(self._raised_faults)
        result.op_count = op_count
        return result

    # ---- condition evaluation ---------------------------------------------

    def _eval_condition_block(
        self,
        when: dict[str, Any],
        prev_snapshot: dict[str, int],
    ) -> bool:
        """Evaluate a top-level condition block (may have all/any/not groups).

        Every group present must hold (the groups are AND-ed). A block that
        is empty, not a dict, or has none of the three group keys is
        unconditionally true.
        """
        # NOTE(review): a bare ``{"fact": ..., "op": ...}`` block without a
        # group key therefore always fires — confirm the schema rejects it.
        if not isinstance(when, dict) or not when:
            # Empty condition block → always true (unconditional rule).
            return True

        for group_type in ("all", "any", "not"):
            group = when.get(group_type)
            if group is None:
                continue
            if not isinstance(group, list):
                group = [group]
            if not self._eval_group(group_type, group, prev_snapshot):
                return False
        return True

    def _eval_group(
        self,
        group_type: str,
        conditions: list[Any],
        prev_snapshot: dict[str, int],
    ) -> bool:
        """Evaluate a condition group (ALL / ANY / NOT).

        Entries that are not dicts are skipped. An unknown *group_type* is
        vacuously true.
        """
        outcomes = (
            self._eval_single_condition(cond, prev_snapshot)
            for cond in conditions
            if isinstance(cond, dict)
        )

        # The generator is lazy, so all()/any() keep the short-circuiting.
        if group_type == "all":
            return all(outcomes)
        if group_type == "any":
            return any(outcomes)
        if group_type == "not":
            # NOT holds only when every inner condition is false.
            return not any(outcomes)

        return True  # unknown group → vacuously true

    def _eval_single_condition(
        self,
        cond: dict[str, Any],
        prev_snapshot: dict[str, int],
    ) -> bool:
        """Evaluate one condition (fact op value).

        An unknown fact reads as 0 and an unknown operator is false.
        *prev_snapshot* supplies the previous-pass values for ``changed``,
        ``delta_gt`` and ``delta_lt``.
        """
        fact_name = cond.get("fact", "")
        op = cond.get("op", "==")
        threshold = cond.get("value", 0)

        fact_val = self._fact_values.get(fact_name, 0)

        if op == "==":
            return self._coerce_eq(fact_val, threshold)
        if op == "!=":
            return not self._coerce_eq(fact_val, threshold)
        if op == "<":
            return fact_val < int(threshold)
        if op == "<=":
            return fact_val <= int(threshold)
        if op == ">":
            return fact_val > int(threshold)
        if op == ">=":
            return fact_val >= int(threshold)
        if op == "in":
            # A scalar threshold degrades to plain equality.
            if isinstance(threshold, list):
                return fact_val in [int(v) for v in threshold]
            return fact_val == int(threshold)
        if op == "not_in":
            if isinstance(threshold, list):
                return fact_val not in [int(v) for v in threshold]
            return fact_val != int(threshold)
        if op == "stale":
            return self._check_stale(fact_name, int(threshold))
        if op == "not_stale":
            return not self._check_stale(fact_name, int(threshold))
        if op == "changed":
            return fact_val != prev_snapshot.get(fact_name, 0)
        if op == "delta_gt":
            return abs(fact_val - prev_snapshot.get(fact_name, 0)) > int(threshold)
        if op == "delta_lt":
            return abs(fact_val - prev_snapshot.get(fact_name, 0)) < int(threshold)

        return False  # unknown operator

    @staticmethod
    def _coerce_eq(fact_val: int, threshold: Any) -> bool:
        """Equality with bool coercion: true/false → 1/0."""
        # int() already maps True/False to 1/0, so no separate bool branch.
        return fact_val == int(threshold)

    @staticmethod
    def _trunc_div(numerator: int, denominator: int) -> int:
        """Divide exactly and truncate toward zero, as C integer division does.

        Pure integer arithmetic: ``int(a / b)`` goes through a float and
        gives wrong answers once the operands exceed 2**53, which the
        widened ``left * right`` products of ``scale``/``accumulate`` can.
        *denominator* must be non-zero.
        """
        quotient = abs(numerator) // abs(denominator)
        if (numerator < 0) != (denominator < 0):
            return -quotient
        return quotient

    def _check_stale(self, fact_name: str, threshold_ms: int) -> bool:
        """Return True if the fact's timestamp is stale w.r.t. snapshot time."""
        ts = self._fact_timestamps.get(fact_name, 0)
        if ts == 0 and not self._fact_valid.get(fact_name, False):
            # Never written → stale.
            return True
        age = self._snapshot_ts - ts
        return age > threshold_ms

    # ---- expression execution ---------------------------------------------

    def _exec_expression(self, expr: dict[str, Any]) -> None:
        """Execute one compute expression, writing the result to a fact.

        Does nothing when the target fact index is out of range.
        """
        target_id = expr.get("target_fact_id", 0)
        target_name = self._fact_name_by_index(target_id)
        if target_name is None:
            return

        op = expr.get("op", "assign")
        left = self._resolve_operand(
            expr.get("left_fact_id", UINT16_MAX),
            expr.get("left_literal", 0),
        )
        right = self._resolve_operand(
            expr.get("right_fact_id", UINT16_MAX),
            expr.get("right_literal", 0),
        )
        scale = expr.get("scale", 1)

        self._fact_values[target_name] = self._compute_op(
            op, target_name, left, right, scale
        )

    def _compute_op(
        self,
        op: str,
        target_name: str,
        left: int,
        right: int,
        scale: int,
    ) -> int:
        """Compute one expression opcode. Returns saturated int32 result.

        Division by zero (``div``, ``mod``, ``scale``) yields 0;
        ``accumulate`` with a zero scale leaves the target unchanged.
        An unknown opcode yields 0.
        """
        if op == "assign":
            return _saturate32(left)

        if op == "add":
            return _saturate32(left + right)

        if op == "sub":
            return _saturate32(left - right)

        if op == "mul":
            return _saturate32(left * right)

        if op == "div":
            if right == 0:
                return 0
            # Truncate toward zero (C behaviour).
            return _saturate32(self._trunc_div(left, right))

        if op == "mod":
            if right == 0:
                return 0
            # Python mod differs from C for negative numbers.
            # C99: result has the sign of the dividend.
            remainder = abs(left) % abs(right)
            return _saturate32(-remainder if left < 0 else remainder)

        if op == "abs":
            return _saturate32(abs(left))

        if op == "negate":
            return _saturate32(-left)

        if op == "min":
            return _saturate32(min(left, right))

        if op == "max":
            return _saturate32(max(left, right))

        if op == "clamp":
            # clamp(left, lo=right, hi=scale)
            return _saturate32(max(right, min(left, scale)))

        # NOTE(review): a negative shift count raises ValueError here
        # (Python semantics) — confirm what the C engine does.
        if op == "shift_r":
            return _saturate32(left >> right)

        if op == "shift_l":
            if right >= 32:
                # Any non-zero value shifted this far is outside int32, so
                # give the saturated answer without building a huge int.
                if left == 0:
                    return 0
                return INT32_MAX if left > 0 else INT32_MIN
            return _saturate32(left << right)

        if op == "scale":
            # target = (left * right) / scale (int64 widening)
            if scale == 0:
                return 0
            return _saturate32(self._trunc_div(left * right, scale))

        if op == "accumulate":
            # target = target + (left * right) / scale (int64 widening)
            current = self._fact_values.get(target_name, 0)
            if scale == 0:
                return _saturate32(current)
            return _saturate32(current + self._trunc_div(left * right, scale))

        return 0  # unknown op

    def _resolve_operand(self, fact_id: int, literal: int) -> int:
        """Resolve an operand: use fact value if fact_id != UINT16_MAX, else literal.

        An out-of-range *fact_id* also falls back to the literal.
        """
        if fact_id == UINT16_MAX:
            return int(literal)
        name = self._fact_name_by_index(fact_id)
        if name is None:
            return int(literal)
        return self._fact_values.get(name, 0)

    def _fact_name_by_index(self, idx: int) -> str | None:
        """Map a canonical fact index back to its name."""
        if 0 <= idx < len(self._model.facts):
            return self._model.facts[idx]["id"]
        return None

    def _find_action(self, action_id: str) -> dict[str, Any] | None:
        """Look up an action definition by its id string."""
        for action in self._model.actions:
            if action.get("id") == action_id:
                return action
        return None
+1,906 @@ +# SPDX-License-Identifier: MIT +"""Comprehensive tests for the Python evaluator (REQ-ARCH-035).""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from arbiter.evaluator import ( + INT32_MAX, + INT32_MIN, + ArbiterEvaluator, + EvalResult, + TraceEntry, + _saturate32, +) + +SAMPLES_DIR = Path(__file__).resolve().parent.parent.parent / "samples" + + +# --------------------------------------------------------------------------- +# Helper: build a minimal model dict +# --------------------------------------------------------------------------- + + +def _model( + facts=None, + rules=None, + actions=None, + modes=None, + *, + name="test_model", +): + """Return a minimal valid ARB model dict.""" + return { + "arb_version": 0.1, + "model": name, + "target": {"rtos": "zephyr"}, + "facts": facts or [], + "rules": rules or [], + "actions": actions or [], + "modes": modes or [], + } + + +def _fact(fid, ftype="int32", **kwargs): + return {"id": fid, "type": ftype, **kwargs} + + +def _rule(rid, when=None, then=None, rclass="inference"): + r = {"id": rid, "class": rclass} + if when is not None: + r["when"] = when + if then is not None: + r["then"] = then + return r + + +# =================================================================== +# 1. 
# ===================================================================
# 1. BASIC EVALUATION
# ===================================================================


class TestBasicEval:
    """Smoke tests: no rules, one conditional rule, one unconditional rule."""

    @staticmethod
    def _threshold_model():
        """Return a fresh model whose rule ``r1`` fires when ``x > 10``."""
        condition = {"fact": "x", "op": ">", "value": 10}
        return _model(
            facts=[_fact("x")],
            rules=[_rule("r1", when={"all": [condition]})],
        )

    def test_no_rules(self):
        """Empty model with no rules should return empty result."""
        evaluator = ArbiterEvaluator(_model(facts=[_fact("x")]))
        outcome = evaluator.eval()
        assert outcome.fired_rules == []
        assert outcome.current_mode is None
        assert outcome.op_count == 0

    def test_single_rule_fires(self):
        """Rule fires when condition is met."""
        evaluator = ArbiterEvaluator(self._threshold_model())
        evaluator.set_fact("x", 20)
        outcome = evaluator.eval()
        assert outcome.fired_rules == ["r1"]

    def test_single_rule_does_not_fire(self):
        """Rule does not fire when condition is not met."""
        evaluator = ArbiterEvaluator(self._threshold_model())
        evaluator.set_fact("x", 5)
        outcome = evaluator.eval()
        assert outcome.fired_rules == []

    def test_unconditional_rule(self):
        """Rule with no 'when' block always fires."""
        model = _model(facts=[_fact("x")], rules=[_rule("r1")])
        outcome = ArbiterEvaluator(model).eval()
        assert outcome.fired_rules == ["r1"]
# ===================================================================
# 2. CONDITION OPERATORS (all 13)
# ===================================================================


class TestConditionOperators:
    """One or more cases for each of the 13 condition operators."""

    def _eval_op(self, op, fact_val, threshold):
        """Helper: evaluate a single condition with the given operator."""
        condition = {"fact": "x", "op": op, "value": threshold}
        evaluator = ArbiterEvaluator(
            _model(
                facts=[_fact("x")],
                rules=[_rule("r", when={"all": [condition]})],
            )
        )
        evaluator.set_fact("x", fact_val)
        return evaluator.eval()

    @staticmethod
    def _single_rule_evaluator(op, threshold, **fact_kwargs):
        """Build an evaluator for fact ``x`` and rule ``r`` using *op*."""
        condition = {"fact": "x", "op": op, "value": threshold}
        model = _model(
            facts=[_fact("x", **fact_kwargs)],
            rules=[_rule("r", when={"all": [condition]})],
        )
        return ArbiterEvaluator(model)

    def test_eq_true(self):
        outcome = self._eval_op("==", 42, 42)
        assert outcome.fired_rules == ["r"]

    def test_eq_false(self):
        outcome = self._eval_op("==", 42, 43)
        assert outcome.fired_rules == []

    def test_eq_bool_true(self):
        outcome = self._eval_op("==", 1, True)
        assert outcome.fired_rules == ["r"]

    def test_eq_bool_false(self):
        outcome = self._eval_op("==", 0, True)
        assert outcome.fired_rules == []

    def test_ne_true(self):
        outcome = self._eval_op("!=", 42, 43)
        assert outcome.fired_rules == ["r"]

    def test_ne_false(self):
        outcome = self._eval_op("!=", 42, 42)
        assert outcome.fired_rules == []

    def test_lt_true(self):
        outcome = self._eval_op("<", 5, 10)
        assert outcome.fired_rules == ["r"]

    def test_lt_false(self):
        outcome = self._eval_op("<", 10, 5)
        assert outcome.fired_rules == []

    def test_le_true_equal(self):
        outcome = self._eval_op("<=", 10, 10)
        assert outcome.fired_rules == ["r"]

    def test_le_true_less(self):
        outcome = self._eval_op("<=", 5, 10)
        assert outcome.fired_rules == ["r"]

    def test_le_false(self):
        outcome = self._eval_op("<=", 11, 10)
        assert outcome.fired_rules == []

    def test_gt_true(self):
        outcome = self._eval_op(">", 10, 5)
        assert outcome.fired_rules == ["r"]

    def test_gt_false(self):
        outcome = self._eval_op(">", 5, 10)
        assert outcome.fired_rules == []

    def test_ge_true_equal(self):
        outcome = self._eval_op(">=", 10, 10)
        assert outcome.fired_rules == ["r"]

    def test_ge_false(self):
        outcome = self._eval_op(">=", 9, 10)
        assert outcome.fired_rules == []

    def test_in_list(self):
        outcome = self._eval_op("in", 2, [1, 2, 3])
        assert outcome.fired_rules == ["r"]

    def test_in_list_miss(self):
        outcome = self._eval_op("in", 5, [1, 2, 3])
        assert outcome.fired_rules == []

    def test_not_in_list(self):
        outcome = self._eval_op("not_in", 5, [1, 2, 3])
        assert outcome.fired_rules == ["r"]

    def test_not_in_list_miss(self):
        outcome = self._eval_op("not_in", 2, [1, 2, 3])
        assert outcome.fired_rules == []

    def test_stale(self):
        """Fact is stale when timestamp age exceeds threshold."""
        evaluator = self._single_rule_evaluator("stale", 100, stale_after_ms=100)
        evaluator.set_fact("x", 42)
        evaluator.set_timestamp("x", 0)
        evaluator.set_snapshot_timestamp(200)
        assert evaluator.eval().fired_rules == ["r"]

    def test_not_stale(self):
        """Fact is not stale when recently updated."""
        evaluator = self._single_rule_evaluator(
            "not_stale", 100, stale_after_ms=100
        )
        evaluator.set_fact("x", 42)
        evaluator.set_timestamp("x", 150)
        evaluator.set_snapshot_timestamp(200)
        assert evaluator.eval().fired_rules == ["r"]

    def test_changed(self):
        """Changed detects when a fact value differs from its previous value."""
        evaluator = self._single_rule_evaluator("changed", 0)
        evaluator.set_fact("x", 10)
        # First pass: previous value is the default (0), current is 10.
        first = evaluator.eval()
        assert first.fired_rules == ["r"]
        # Second pass: previous value is now 10 and nothing was written.
        second = evaluator.eval()
        assert second.fired_rules == []

    def test_delta_gt(self):
        """Delta > threshold."""
        evaluator = self._single_rule_evaluator("delta_gt", 5)
        evaluator.set_fact("x", 10)
        assert evaluator.eval().fired_rules == ["r"]  # |10 - 0| = 10 > 5

    def test_delta_lt(self):
        """Delta < threshold."""
        evaluator = self._single_rule_evaluator("delta_lt", 5)
        evaluator.set_fact("x", 2)
        assert evaluator.eval().fired_rules == ["r"]  # |2 - 0| = 2 < 5
# ===================================================================
# 3. CONDITION GROUPS
# ===================================================================


class TestConditionGroups:
    """ALL / ANY / NOT group semantics."""

    @staticmethod
    def _fired(group_type, conditions, fact_values):
        """Run one rule ``r`` guarded by a single group; return fired ids."""
        model = _model(
            facts=[_fact(fact_id) for fact_id in fact_values],
            rules=[_rule("r", when={group_type: conditions})],
        )
        evaluator = ArbiterEvaluator(model)
        for fact_id, value in fact_values.items():
            evaluator.set_fact(fact_id, value)
        return evaluator.eval().fired_rules

    @staticmethod
    def _both_above_ten():
        """Two conditions: ``x > 10`` and ``y > 10``."""
        return [
            {"fact": "x", "op": ">", "value": 10},
            {"fact": "y", "op": ">", "value": 10},
        ]

    def test_all_short_circuit(self):
        """ALL group short-circuits on first false."""
        # x fails, so the group is false even though y passes.
        fired = self._fired("all", self._both_above_ten(), {"x": 5, "y": 20})
        assert fired == []

    def test_any_short_circuit(self):
        """ANY group short-circuits on first true."""
        fired = self._fired("any", self._both_above_ten(), {"x": 20, "y": 5})
        assert fired == ["r"]

    def test_not_group(self):
        """NOT group inverts: fires if inner condition is false."""
        inner = [{"fact": "x", "op": "==", "value": 1}]
        assert self._fired("not", inner, {"x": 0}) == ["r"]

    def test_not_group_fails(self):
        """NOT group fails when inner condition is true."""
        inner = [{"fact": "x", "op": "==", "value": 1}]
        assert self._fired("not", inner, {"x": 1}) == []
# ===================================================================
# 4. SAFETY GUARD ORDERING
# ===================================================================


class TestSafetyGuardOrdering:
    """Rule classes are evaluated by priority, not by declaration order."""

    def test_safety_guard_fires_before_inference(self):
        """Safety guard rules must execute before inference rules."""
        declared = [
            _rule("infer_first_alpha", rclass="inference"),
            _rule("safety_second_alpha", rclass="safety_guard"),
        ]
        evaluator = ArbiterEvaluator(_model(facts=[_fact("x")], rules=declared))
        fired = evaluator.eval().fired_rules
        # Both rules are unconditional; only their order is under test.
        assert fired[0] == "safety_second_alpha"
        assert fired[1] == "infer_first_alpha"

    def test_full_class_ordering(self):
        """All rule classes are evaluated in the correct priority order."""
        declaration_order = (
            "advisory",
            "inference",
            "constraint",
            "mode_guard",
            "obligation",
            "safety_guard",
        )
        declared = [
            _rule(f"{rule_class}_r", rclass=rule_class)
            for rule_class in declaration_order
        ]
        evaluator = ArbiterEvaluator(_model(facts=[_fact("x")], rules=declared))
        fired = evaluator.eval().fired_rules
        assert fired == [
            "safety_guard_r",
            "obligation_r",
            "constraint_r",
            "mode_guard_r",
            "inference_r",
            "advisory_r",
        ]
# ===================================================================
# 5. EXPRESSION OPCODES (all 15)
# ===================================================================


class TestExpressionOpcodes:
    """One or more cases for each of the 15 compute opcodes."""

    def _eval_compute(self, facts, exprs, fact_values=None):
        """Helper: run a rule with compute expressions and return fact values."""
        model = _model(
            facts=facts,
            rules=[_rule("r", then={"compute": exprs})],
        )
        evaluator = ArbiterEvaluator(model)
        if fact_values:
            for fact_id, value in fact_values.items():
                evaluator.set_fact(fact_id, value)
        evaluator.eval()
        return evaluator._fact_values

    def _binary(self, op, a, b, **extra):
        """Run ``out = op(a, b)`` on facts a, b, out; return ``out``."""
        expr = {"target": "out", "op": op, "left": "a", "right": "b", **extra}
        values = self._eval_compute(
            [_fact("a"), _fact("b"), _fact("out")],
            [expr],
            {"a": a, "b": b},
        )
        return values["out"]

    def _unary(self, op, a, **extra):
        """Run ``out = op(a)`` on facts a, out; return ``out``."""
        expr = {"target": "out", "op": op, "left": "a", **extra}
        values = self._eval_compute(
            [_fact("a"), _fact("out")],
            [expr],
            {"a": a},
        )
        return values["out"]

    def test_assign_literal(self):
        values = self._eval_compute(
            [_fact("out")],
            [{"target": "out", "op": "assign", "left_literal": 42}],
        )
        assert values["out"] == 42

    def test_assign_fact(self):
        assert self._unary("assign", 99) == 99

    def test_add(self):
        assert self._binary("add", 100, 200) == 300

    def test_sub(self):
        assert self._binary("sub", 100, 30) == 70

    def test_mul(self):
        assert self._binary("mul", 7, 6) == 42

    def test_div(self):
        assert self._binary("div", 100, 3) == 33  # truncate toward zero

    def test_div_by_zero(self):
        assert self._binary("div", 100, 0) == 0

    def test_mod(self):
        assert self._binary("mod", 17, 5) == 2

    def test_mod_negative(self):
        """C-style mod: sign follows dividend."""
        assert self._binary("mod", -17, 5) == -2

    def test_abs(self):
        assert self._unary("abs", -42) == 42

    def test_negate(self):
        assert self._unary("negate", 42) == -42

    def test_min(self):
        assert self._binary("min", 10, 3) == 3

    def test_max(self):
        assert self._binary("max", 10, 3) == 10

    def test_clamp(self):
        """clamp(left, lo=right, hi=scale)."""
        clamped = self._unary("clamp", 200, right_literal=-100, scale=100)
        assert clamped == 100  # clamped to hi

    def test_clamp_lo(self):
        clamped = self._unary("clamp", -200, right_literal=-100, scale=100)
        assert clamped == -100  # clamped to lo

    def test_shift_r(self):
        assert self._unary("shift_r", 100, right_literal=2) == 25

    def test_shift_l(self):
        assert self._unary("shift_l", 5, right_literal=3) == 40

    def test_scale(self):
        """scale: target = (left * right) / scale."""
        scaled = self._binary("scale", 5000, 2500, scale=1000)
        assert scaled == 12500  # 5000*2500/1000

    def test_scale_saturation(self):
        """scale with large values should saturate to INT32_MAX."""
        assert self._binary("scale", INT32_MAX, 2, scale=1) == INT32_MAX

    def test_accumulate(self):
        """accumulate: target = target + (left * right) / scale."""
        seed = {"target": "out", "op": "assign", "left_literal": 100}
        step = {
            "target": "out",
            "op": "accumulate",
            "left": "a",
            "right": "b",
            "scale": 10,
        }
        values = self._eval_compute(
            [_fact("a"), _fact("b"), _fact("out")],
            [seed, step],
            {"a": 50, "b": 3},
        )
        assert values["out"] == 115  # 100 + (50*3)/10 = 100 + 15
INT32 SATURATION
+# ===================================================================
+
+
+class TestSaturation:
+    """Results that leave the int32 range must be pinned to INT32_MIN / INT32_MAX."""
+
+    def test_saturate32_max(self):
+        """One above INT32_MAX saturates to INT32_MAX."""
+        too_big = INT32_MAX + 1
+        assert _saturate32(too_big) == INT32_MAX
+
+    def test_saturate32_min(self):
+        """One below INT32_MIN saturates to INT32_MIN."""
+        too_small = INT32_MIN - 1
+        assert _saturate32(too_small) == INT32_MIN
+
+    def test_add_overflow(self):
+        """INT32_MAX + 1 through the ``add`` opcode leaves the target at INT32_MAX."""
+        declared = [_fact("a"), _fact("b"), _fact("out")]
+        add_step = {"target": "out", "op": "add", "left": "a", "right": "b"}
+        model = _model(
+            facts=declared,
+            rules=[_rule("r", then={"compute": [add_step]})],
+        )
+        evaluator = ArbiterEvaluator(model)
+        for fact_id, value in (("a", INT32_MAX), ("b", 1)):
+            evaluator.set_fact(fact_id, value)
+        evaluator.eval()
+        assert evaluator._fact_values["out"] == INT32_MAX
+
+    def test_sub_underflow(self):
+        """INT32_MIN - 1 through the ``sub`` opcode leaves the target at INT32_MIN."""
+        declared = [_fact("a"), _fact("b"), _fact("out")]
+        sub_step = {"target": "out", "op": "sub", "left": "a", "right": "b"}
+        model = _model(
+            facts=declared,
+            rules=[_rule("r", then={"compute": [sub_step]})],
+        )
+        evaluator = ArbiterEvaluator(model)
+        for fact_id, value in (("a", INT32_MIN), ("b", 1)):
+            evaluator.set_fact(fact_id, value)
+        evaluator.eval()
+        assert evaluator._fact_values["out"] == INT32_MIN
+
+
+# ===================================================================
+# 7. MODE TRANSITIONS
+# ===================================================================
+
+
+class TestModeTransitions:
+    """Mode selection driven by ``set_mode`` rule outcomes."""
+
+    def test_mode_set(self):
+        """A rule built without a ``when`` clause sets the reported mode."""
+        model = _model(
+            facts=[_fact("x")],
+            modes=[{"id": "mode.a"}, {"id": "mode.b"}],
+            rules=[_rule("r", then={"set_mode": "mode.a"})],
+        )
+        outcome = ArbiterEvaluator(model).eval()
+        assert outcome.current_mode == "mode.a"
+
+    def test_last_mode_wins(self):
+        """With several mode-setting rules, the last one in eval order decides."""
+        first = _rule("r1", then={"set_mode": "mode.a"})
+        second = _rule("r2", then={"set_mode": "mode.b"})
+        model = _model(
+            facts=[_fact("x")],
+            modes=[{"id": "mode.a"}, {"id": "mode.b"}],
+            rules=[first, second],
+        )
+        outcome = ArbiterEvaluator(model).eval()
+        assert outcome.current_mode == "mode.b"
+
+
+# ===================================================================
+# 8. 
ACTION COLLECTION
+# ===================================================================
+
+
+class TestActionCollection:
+    """Actions named by fired rules are reported on the eval result."""
+
+    def test_action_collected(self):
+        """A fired rule's ``action`` appears in ``requested_actions``."""
+        model = _model(
+            facts=[_fact("x")],
+            actions=[{"id": "act1", "type": "callback"}],
+            rules=[_rule("r", then={"action": "act1"})],
+        )
+        outcome = ArbiterEvaluator(model).eval()
+        assert outcome.requested_actions == ["act1"]
+
+
+# ===================================================================
+# 9. FAULT RAISE / CLEAR
+# ===================================================================
+
+
+class TestFaults:
+    """Raising, clearing and retaining faults across evaluations."""
+
+    def test_raise_fault(self):
+        """``raise_fault`` puts the fault id into ``raised_faults``."""
+        model = _model(
+            facts=[_fact("x")],
+            rules=[_rule("r", then={"raise_fault": "fault.overheat"})],
+        )
+        outcome = ArbiterEvaluator(model).eval()
+        assert "fault.overheat" in outcome.raised_faults
+
+    def test_clear_fault(self):
+        """A later ``clear_fault`` removes a fault raised earlier in the same eval."""
+        raiser = _rule("r1", then={"raise_fault": "fault.overheat"})
+        clearer = _rule("r2", then={"clear_fault": "fault.overheat"})
+        model = _model(facts=[_fact("x")], rules=[raiser, clearer])
+        outcome = ArbiterEvaluator(model).eval()
+        assert "fault.overheat" not in outcome.raised_faults
+
+    def test_fault_persists_across_evals(self):
+        """A raised fault stays raised on a later eval where its rule no longer fires."""
+        guarded_rule = _rule(
+            "r1",
+            when={"all": [{"fact": "x", "op": "==", "value": 1}]},
+            then={"raise_fault": "fault.overheat"},
+        )
+        model = _model(facts=[_fact("x")], rules=[guarded_rule])
+        evaluator = ArbiterEvaluator(model)
+
+        evaluator.set_fact("x", 1)
+        first_pass = evaluator.eval()
+        assert "fault.overheat" in first_pass.raised_faults
+
+        # Second eval with x=0 — rule doesn't fire but fault persists.
+        evaluator.set_fact("x", 0)
+        second_pass = evaluator.eval()
+        assert "fault.overheat" in second_pass.raised_faults
+
+
+# ===================================================================
+# 10. 
DETERMINISM
+# ===================================================================
+
+
+class TestDeterminism:
+    """Repeated evaluation of the same inputs must not vary."""
+
+    def test_same_input_same_output(self):
+        """Five fresh evaluators fed identical facts serialize to identical results."""
+        model = _model(
+            facts=[_fact("x"), _fact("y")],
+            rules=[
+                _rule(
+                    "r1",
+                    when={"all": [{"fact": "x", "op": ">", "value": 5}]},
+                    then={"set_mode": "active"},
+                ),
+                _rule("r2", when={"all": [{"fact": "y", "op": "==", "value": 1}]}),
+            ],
+            modes=[{"id": "active"}],
+        )
+
+        def run_once():
+            evaluator = ArbiterEvaluator(model)
+            evaluator.set_fact("x", 10)
+            evaluator.set_fact("y", 1)
+            return evaluator.eval().to_dict()
+
+        snapshots = [run_once() for _ in range(5)]
+
+        # Every later run must match the first one exactly.
+        baseline = snapshots[0]
+        for later in snapshots[1:]:
+            assert later == baseline
+
+
+# ===================================================================
+# 11. TRACE
+# ===================================================================
+
+
+class TestTrace:
+    """The eval result carries one trace entry per rule, with its fired flag."""
+
+    def test_trace_records_all_rules(self):
+        """Both rules appear in the trace, whether or not they fired."""
+        above = _rule("r1", when={"all": [{"fact": "x", "op": ">", "value": 5}]})
+        below = _rule("r2", when={"all": [{"fact": "x", "op": "<", "value": 5}]})
+        model = _model(facts=[_fact("x")], rules=[above, below])
+        evaluator = ArbiterEvaluator(model)
+        evaluator.set_fact("x", 10)
+        outcome = evaluator.eval()
+
+        assert len(outcome.trace) == 2
+        traced_ids = [entry.rule_id for entry in outcome.trace]
+        for expected_id in ("r1", "r2"):
+            assert expected_id in traced_ids
+
+    def test_trace_records_fired_status(self):
+        """``fired`` is True when the condition holds and False when it does not."""
+        model = _model(
+            facts=[_fact("x")],
+            rules=[_rule("r1", when={"all": [{"fact": "x", "op": "==", "value": 1}]})],
+        )
+
+        matching = ArbiterEvaluator(model)
+        matching.set_fact("x", 1)
+        fired_outcome = matching.eval()
+        assert fired_outcome.trace[0].fired is True
+
+        non_matching = ArbiterEvaluator(model)
+        non_matching.set_fact("x", 0)
+        idle_outcome = non_matching.eval()
+        assert idle_outcome.trace[0].fired is False
+
+
+# ===================================================================
+# 12. 
EVAL RESULT SERIALIZATION
+# ===================================================================
+
+
+class TestEvalResultSerialization:
+    """``EvalResult.to_dict`` must yield plain, JSON-friendly data."""
+
+    def test_to_dict(self):
+        """Fields are carried over, faults come out sorted, and JSON round-trips."""
+        result = EvalResult(
+            fired_rules=["r1"],
+            current_mode="active",
+            raised_faults={"fault.a", "fault.b"},
+            requested_actions=["act1"],
+            op_count=5,
+            trace=[TraceEntry("r1", True, "reason")],
+        )
+        payload = result.to_dict()
+
+        expected_fields = {
+            "fired_rules": ["r1"],
+            "current_mode": "active",
+            # The set of faults is expected back as a sorted list.
+            "raised_faults": ["fault.a", "fault.b"],
+            "op_count": 5,
+        }
+        for key, wanted in expected_fields.items():
+            assert payload[key] == wanted
+
+        # JSON-roundtrip should work.
+        encoded = json.dumps(payload)
+        assert json.loads(encoded) == payload
+
+
+# ===================================================================
+# 13. SAMPLE MODEL EVALUATION
+# ===================================================================
+
+
+class TestSampleModels:
+    """Runs the battery sample model shipped under SAMPLES_DIR."""
+
+    def test_battery_critical(self):
+        """Battery model: voltage < 3000 triggers critical safety guard."""
+        import yaml
+
+        models_dir = SAMPLES_DIR / "battery_policy" / "models"
+        model_path = models_dir / "battery.arb.yaml"
+        model_data = yaml.safe_load(model_path.read_text(encoding="utf-8"))
+        evaluator = ArbiterEvaluator(model_data)
+
+        readings = {
+            "battery.voltage_mv": 2900,
+            "battery.current_ma": 0,
+            "battery.temp_c": 25,
+            "charger.enabled": False,
+        }
+        for fact_id, value in readings.items():
+            evaluator.set_fact(fact_id, value)
+
+        outcome = evaluator.eval()
+        # Safety guards fire first. critical_battery should fire.
+        assert "rule.critical_battery" in outcome.fired_rules
+        assert outcome.current_mode in ("mode.critical", "mode.low_battery")
+
+    def test_battery_normal(self):
+        """Battery model: normal voltage, charger on → charging mode."""
+        import yaml
+
+        models_dir = SAMPLES_DIR / "battery_policy" / "models"
+        model_path = models_dir / "battery.arb.yaml"
+        model_data = yaml.safe_load(model_path.read_text(encoding="utf-8"))
+        evaluator = ArbiterEvaluator(model_data)
+
+        readings = {
+            "battery.voltage_mv": 3800,
+            "battery.current_ma": 500,
+            "battery.temp_c": 25,
+            "charger.enabled": True,
+        }
+        for fact_id, value in readings.items():
+            evaluator.set_fact(fact_id, value)
+
+        outcome = evaluator.eval()
+        assert "rule.charging" in outcome.fired_rules
+        assert outcome.current_mode == "mode.charging"
+
+
+# ===================================================================
+# 14. PID COMPUTE MODEL
+# ===================================================================
+
+
+class TestPidModel:
+    """Runs the PID controller sample model shipped under SAMPLES_DIR."""
+
+    def test_pid_compute(self):
+        """PID model: enabled + valid sensor → compute PID terms."""
+        import yaml
+
+        models_dir = SAMPLES_DIR / "pid_controller" / "models"
+        model_path = models_dir / "pid_engine.arb.yaml"
+        model_data = yaml.safe_load(model_path.read_text(encoding="utf-8"))
+        evaluator = ArbiterEvaluator(model_data)
+
+        inputs = {
+            "in.enable": True,
+            "in.sensor_valid": True,
+            "in.process_value": 90000,
+            "in.setpoint": 100000,
+            "in.dt_ms": 10,
+            "gain.kp": 2500,
+            "gain.ki": 100,
+            "gain.kd": 800,
+        }
+        for fact_id, value in inputs.items():
+            evaluator.set_fact(fact_id, value)
+
+        outcome = evaluator.eval()
+        assert "10_pid.compute" in outcome.fired_rules
+        computed = evaluator._fact_values
+        # Error should be 100000 - 90000 = 10000
+        assert computed["pid.error"] == 10000
+        # P-term should be (10000 * 2500) / 1000 = 25000
+        assert computed["pid.p_term"] == 25000
+
+    def test_pid_sensor_fault(self):
+        """PID model: sensor invalid → safety guard fires, output zeroed."""
+        import yaml
+
+        models_dir = SAMPLES_DIR / "pid_controller" / "models"
+        model_path = models_dir / "pid_engine.arb.yaml"
+        model_data = yaml.safe_load(model_path.read_text(encoding="utf-8"))
+        evaluator = ArbiterEvaluator(model_data)
+
+        inputs = {
+            "in.enable": True,
+            "in.sensor_valid": False,
+            "in.process_value": 90000,
+            "in.setpoint": 100000,
+            "in.dt_ms": 10,
+        }
+        for fact_id, value in inputs.items():
+            evaluator.set_fact(fact_id, value)
+
+        outcome = evaluator.eval()
+        assert "01_fault.sensor" in outcome.fired_rules
+        assert outcome.current_mode == "mode.sensor_fault"
+        assert evaluator._fact_values["pid.output"] == 0
+
+
+# ===================================================================
+# 15. STALENESS EDGE CASES
+# ===================================================================
+
+
+class TestStalenessEdgeCases:
+    """Boundary behaviour of the ``stale`` condition operator."""
+
+    def test_never_written_is_stale(self):
+        """A fact that was never written should be considered stale."""
+        stale_check = {"fact": "x", "op": "stale", "value": 100}
+        model = _model(
+            facts=[_fact("x", stale_after_ms=100)],
+            rules=[_rule("r", when={"all": [stale_check]})],
+        )
+        evaluator = ArbiterEvaluator(model)
+        evaluator.set_snapshot_timestamp(1000)
+        outcome = evaluator.eval()
+        assert outcome.fired_rules == ["r"]
+
+    def test_exact_threshold_not_stale(self):
+        """Age exactly equal to threshold should NOT be stale (> not >=)."""
+        stale_check = {"fact": "x", "op": "stale", "value": 100}
+        model = _model(
+            facts=[_fact("x", stale_after_ms=100)],
+            rules=[_rule("r", when={"all": [stale_check]})],
+        )
+        evaluator = ArbiterEvaluator(model)
+        evaluator.set_fact("x", 42)
+        evaluator.set_timestamp("x", 0)
+        evaluator.set_snapshot_timestamp(100)
+        outcome = evaluator.eval()
+        # age = 100, threshold = 100 → not stale
+        assert outcome.fired_rules == []
+
+
+# ===================================================================
+# 16. 
SET_FACT FROM STRING (CLI COERCION)
+# ===================================================================
+
+
+class TestSetFactCoercion:
+    """``set_fact`` accepts string values and stores them as integers."""
+
+    def test_string_true(self):
+        """The string "true" on a bool fact is stored as 1."""
+        evaluator = ArbiterEvaluator(_model(facts=[_fact("x", "bool")]))
+        evaluator.set_fact("x", "true")
+        assert evaluator._fact_values["x"] == 1
+
+    def test_string_false(self):
+        """The string "false" on a bool fact is stored as 0."""
+        evaluator = ArbiterEvaluator(_model(facts=[_fact("x", "bool")]))
+        evaluator.set_fact("x", "false")
+        assert evaluator._fact_values["x"] == 0
+
+    def test_string_int(self):
+        """A decimal string is stored as the integer it spells."""
+        evaluator = ArbiterEvaluator(_model(facts=[_fact("x")]))
+        evaluator.set_fact("x", "42")
+        assert evaluator._fact_values["x"] == 42
+
+    def test_unknown_fact_raises(self):
+        """Setting a fact the model does not declare raises KeyError."""
+        evaluator = ArbiterEvaluator(_model(facts=[_fact("x")]))
+        with pytest.raises(KeyError):
+            evaluator.set_fact("nonexistent", 1)