Skip to content

Commit 296d2b1

Browse files
committed
test(ai): add regression tests for AsyncStreamWrapper
1 parent 0eabcfe commit 296d2b1

1 file changed

Lines changed: 127 additions & 0 deletions

File tree

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""Regression tests for AsyncStreamWrapper.
2+
3+
Ensures that PostHog AI streaming wrappers return objects that support both
4+
the async iterator protocol (``async for``) and the async context manager
5+
protocol (``async with``), as required by libraries such as pydantic-ai.
6+
7+
Issue: https://github.com/PostHog/posthog-python/issues/393
8+
"""
9+
10+
import pytest
11+
12+
from posthog.ai.stream import AsyncStreamWrapper
13+
14+
15+
# ---------------------------------------------------------------------------
16+
# Helpers
17+
# ---------------------------------------------------------------------------
18+
19+
20+
async def _make_gen(items):
21+
"""Simple async generator that yields the given items."""
22+
for item in items:
23+
yield item
24+
25+
26+
# ---------------------------------------------------------------------------
27+
# Tests
28+
# ---------------------------------------------------------------------------
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_async_for_iteration():
33+
"""AsyncStreamWrapper must yield all items when used with ``async for``."""
34+
wrapper = AsyncStreamWrapper(_make_gen([1, 2, 3]))
35+
result = []
36+
async for item in wrapper:
37+
result.append(item)
38+
assert result == [1, 2, 3]
39+
40+
41+
@pytest.mark.asyncio
42+
async def test_async_context_manager_protocol():
43+
"""AsyncStreamWrapper must support ``async with`` without raising TypeError."""
44+
wrapper = AsyncStreamWrapper(_make_gen(["a", "b"]))
45+
46+
# This is the call pattern that pydantic-ai uses and that previously raised:
47+
# TypeError: 'async_generator' object does not support the asynchronous
48+
# context manager protocol
49+
async with wrapper as stream:
50+
result = []
51+
async for chunk in stream:
52+
result.append(chunk)
53+
54+
assert result == ["a", "b"]
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_context_manager_returns_self():
59+
"""``async with wrapper as w`` should bind the wrapper itself."""
60+
wrapper = AsyncStreamWrapper(_make_gen([]))
61+
async with wrapper as w:
62+
assert w is wrapper
63+
64+
65+
@pytest.mark.asyncio
66+
async def test_finally_block_runs_on_early_exit():
67+
"""The underlying generator's finally block must run even when the caller
68+
breaks out of the loop early (i.e. doesn't fully exhaust the generator)."""
69+
finally_ran = []
70+
71+
async def gen_with_finally():
72+
try:
73+
for i in range(10):
74+
yield i
75+
finally:
76+
finally_ran.append(True)
77+
78+
wrapper = AsyncStreamWrapper(gen_with_finally())
79+
async with wrapper as stream:
80+
async for chunk in stream:
81+
if chunk == 2:
82+
break # early exit
83+
84+
# __aexit__ must have called aclose(), triggering the finally block
85+
assert finally_ran == [True], "finally block in generator did not run on early exit"
86+
87+
88+
@pytest.mark.asyncio
89+
async def test_finally_block_runs_on_full_exhaustion():
90+
"""The underlying generator's finally block must also run on normal
91+
exhaustion (``aclose()`` on an exhausted generator is a no-op)."""
92+
finally_ran = []
93+
94+
async def gen_with_finally():
95+
try:
96+
yield 1
97+
yield 2
98+
finally:
99+
finally_ran.append(True)
100+
101+
wrapper = AsyncStreamWrapper(gen_with_finally())
102+
async with wrapper as stream:
103+
async for _ in stream:
104+
pass
105+
106+
assert finally_ran == [True]
107+
108+
109+
@pytest.mark.asyncio
110+
async def test_attribute_proxy():
111+
"""Attributes not on AsyncStreamWrapper itself should be forwarded to the
112+
underlying generator (for provider-specific metadata access)."""
113+
114+
class FakeStream:
115+
extra_attr = "hello"
116+
117+
def __aiter__(self):
118+
return self
119+
120+
async def __anext__(self):
121+
raise StopAsyncIteration
122+
123+
async def aclose(self):
124+
pass
125+
126+
wrapper = AsyncStreamWrapper(FakeStream()) # type: ignore[arg-type]
127+
assert wrapper.extra_attr == "hello"

0 commit comments

Comments
 (0)