Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self) -> None:
self.hw_address = t.EUI64.convert("11:22:33:44:55:66:77:88")
self.handlers: dict[str, Callable[[Any, int], Awaitable[Any]]] = {
"ping": self.on_ping,
"reset": self.on_status,
"configure": self.on_configure,
"get_network_info": self.on_get_network_info,
"get_hw_address": self.on_get_hw_address,
Expand Down Expand Up @@ -159,6 +160,38 @@ async def send_event_data(
{"type": "event", "id": request_id, "event": event, "data": data}
)

async def send_confirm(
self, request_id: int, *, next_hop: str | None = None, reason: str | None = None
) -> None:
if reason is not None:
data: dict[str, Any] = {
"id": request_id,
"status": "failed",
"reason": reason,
}
else:
data = {"id": request_id, "status": "confirmed", "next_hop": next_hop}

await self.ws.send_json(
{"type": "notification", "event": "send_confirm", "data": data}
)

async def aps_ack_confirm(
self, request_id: int, *, reason: str | None = None
) -> None:
if reason is not None:
data: dict[str, Any] = {
"id": request_id,
"status": "failed",
"reason": reason,
}
else:
data = {"id": request_id, "status": "confirmed"}

await self.ws.send_json(
{"type": "notification", "event": "aps_ack_confirm", "data": data}
)

async def send_notification(self, notification: commands.Notification) -> None:
await self.ws.send_json(
{
Expand Down Expand Up @@ -208,8 +241,10 @@ async def on_get_hw_address(
async def on_send_aps(
self, command: commands.SendAps, request_id: int
) -> commands.Status:
await self.send_event(request_id, "transmitted")
return commands.Status(status="delivered" if command.aps_ack else "sent")
await self.send_confirm(request_id)
if command.aps_ack:
await self.aps_ack_confirm(request_id)
return commands.Status(status="accepted")

async def on_energy_scan(
self, command: commands.EnergyScan, request_id: int
Expand Down
45 changes: 26 additions & 19 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
from collections.abc import AsyncIterator
from dataclasses import replace

import pytest
from zigpy.exceptions import DeliveryError
Expand Down Expand Up @@ -62,55 +63,61 @@ async def fail(command: commands.Ping, request_id: int) -> commands.Status:
await api.request(commands.Ping())


async def test_request_transmitted(
async def test_request_confirmed(api: RecordingApi, server: SyntheticZiggurat) -> None:
"""An APS-ack send resolves once the end-to-end APS ack arrives."""
await api.request_confirmed(SEND_APS)
assert server.sent(commands.SendAps)[-1].aps_seq == 55


async def test_request_confirmed_next_hop(
api: RecordingApi, server: SyntheticZiggurat
) -> None:
await api.request_transmitted(SEND_APS)
assert server.sent(commands.SendAps)[-1].aps_seq == 55
"""A no-ack unicast resolves on the local handoff."""
await api.request_confirmed(replace(SEND_APS, aps_ack=False))


async def test_request_transmitted_failure_before_transmission(
async def test_request_confirmed_rejected(
api: RecordingApi, server: SyntheticZiggurat
) -> None:
"""Stage two: the stack rejects the frame, so the send raises before any confirm."""

async def fail(command: commands.SendAps, request_id: int) -> commands.Status:
raise RpcError("transmit_failed", "channel busy")

server.handlers["send_aps"] = fail

with pytest.raises(DeliveryError, match="transmit_failed"):
await api.request_transmitted(SEND_APS)
await api.request_confirmed(SEND_APS)


async def test_late_delivery_failure_is_logged(
api: RecordingApi, server: SyntheticZiggurat, caplog: pytest.LogCaptureFixture
async def test_request_confirmed_failure(
api: RecordingApi, server: SyntheticZiggurat
) -> None:
"""The frame is handed off but the end-to-end APS ack never arrives."""

async def ack_timeout(
command: commands.SendAps, request_id: int
) -> commands.Status:
await server.send_event(request_id, "transmitted")
raise RpcError("aps_ack_timeout", "no ack")
await server.send_confirm(request_id)
await server.aps_ack_confirm(request_id, reason="APS ack timed out")
return commands.Status(status="accepted")

server.handlers["send_aps"] = ack_timeout

# Resolves at the `transmitted` stage; the terminal failure arrives later and is
# logged instead of raised
await api.request_transmitted(SEND_APS)

async with asyncio.timeout(1):
while "Delivery failed after transmission" not in caplog.text:
await asyncio.sleep(0.01)
with pytest.raises(DeliveryError, match="APS ack timed out"):
await api.request_confirmed(SEND_APS)


async def test_unsolicited_messages_are_ignored(
api: RecordingApi, server: SyntheticZiggurat, caplog: pytest.LogCaptureFixture
) -> None:
await server.send_raw("not json")
await server.send_raw('{"type": "response", "id": 9999, "result": {}}')
await server.send_raw('{"type": "event", "id": 9999, "event": "transmitted"}')
await server.send_raw('{"type": "event", "id": 9999, "event": "spurious"}')

# A `transmitted` event for a request that did not ask for one
# An unknown event for an in-flight request is ignored (only stream results match)
async def eager(command: commands.Ping, request_id: int) -> commands.Status:
await server.send_event(request_id, "transmitted")
await server.send_event(request_id, "spurious")
return commands.Status(status="pong")

server.handlers["ping"] = eager
Expand Down
Loading
Loading