67 changes: 51 additions & 16 deletions openkb/agent/compiler.py
@@ -6,6 +6,13 @@
Step 3: A + summary → concepts plan (create/update/related).
Step 4: Concurrent LLM calls (A cached) → generate new + rewrite updated concepts.
Step 5: Code adds cross-ref links to related concepts, updates index.

Anthropic prompt caching is enabled via ``cache_control`` markers at two
breakpoints: end of the document message (caches system + doc across all
N+M+2 calls) and end of the assistant summary message (caches the additional
summary prefix across N+M concept-generation calls). Providers that do not
support cache_control receive a normalized list-of-blocks content payload,
which LiteLLM passes through cleanly.
"""
from __future__ import annotations

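A rough sketch of what the two breakpoints buy (system_msg, doc_msg and summary_msg are the objects built later in this file; per_concept_prompt is a stand-in for the formatted _CONCEPT_PAGE_USER / _CONCEPT_UPDATE_USER text):

messages = [
    system_msg,    # plain string content, included in the cached prefix
    doc_msg,       # breakpoint 1: cache_control block, caches (system + doc)
    summary_msg,   # breakpoint 2: extends the prefix to (system + doc + summary)
    {"role": "user", "content": per_concept_prompt},  # unique suffix, never cached
]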
@@ -131,6 +138,17 @@
# LLM helpers
# ---------------------------------------------------------------------------

def _cached_text(text: str) -> list[dict]:
"""Wrap a text payload into a content-block list with an Anthropic
ephemeral cache_control marker.

LiteLLM passes the marker through to Anthropic (and OpenRouter →
Anthropic). For providers that ignore cache_control, the list-of-blocks
payload remains a valid OpenAI-compatible content shape.
"""
return [{"type": "text", "text": text, "cache_control": {"type": "ephemeral"}}]


class _Spinner:
"""Animated dots spinner that runs in a background thread."""

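As a usage sketch (the literal string is only a placeholder), the helper returns a single-element block list that doubles as the cache breakpoint:

doc_msg = {"role": "user", "content": _cached_text("document text ...")}
# doc_msg["content"] ==
#     [{"type": "text",
#       "text": "document text ...",
#       "cache_control": {"type": "ephemeral"}}]
# Anthropic reads the marker as a cache breakpoint; providers that ignore
# cache_control still accept the list-of-blocks content shape.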
@@ -168,15 +186,23 @@ def _format_usage(elapsed: float, usage) -> str:


def _fmt_messages(messages: list[dict], max_content: int = 200) -> str:
"""Format messages for debug output, truncating long content."""
"""Format messages for debug output, truncating long content.

Accepts both plain-string content and the list-of-blocks shape used by
cache_control-tagged messages (joins all text blocks for preview).
"""
parts = []
for msg in messages:
role = msg["role"]
content = msg["content"]
if len(content) > max_content:
preview = content[:max_content] + f"... ({len(content)} chars)"
raw = msg["content"]
if isinstance(raw, list):
text = "".join(b.get("text", "") for b in raw if isinstance(b, dict))
else:
preview = content
text = raw
if len(text) > max_content:
preview = text[:max_content] + f"... ({len(text)} chars)"
else:
preview = text
parts.append(f" [{role}] {preview}")
return "\n".join(parts)

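A small sketch of the two content shapes the preview helper accepts (inputs are illustrative):

_fmt_messages([
    {"role": "system", "content": "plain string prompt"},
    {"role": "user", "content": [{"type": "text", "text": "block content",
                                  "cache_control": {"type": "ephemeral"}}]},
])
# Both entries render as "[role] preview" lines; list content is flattened
# by joining the text of each dict block before truncation.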
@@ -199,13 +225,15 @@ def _llm_call(model: str, messages: list[dict], step_name: str, **kwargs) -> str
return content.strip()


async def _llm_call_async(model: str, messages: list[dict], step_name: str) -> str:
async def _llm_call_async(model: str, messages: list[dict], step_name: str, **kwargs) -> str:
"""Async LLM call with timing output and debug logging."""
logger.debug("LLM request [%s]:\n%s", step_name, _fmt_messages(messages))
if kwargs:
logger.debug("LLM kwargs [%s]: %s", step_name, kwargs)

t0 = time.time()

response = await litellm.acompletion(model=model, messages=messages)
response = await litellm.acompletion(model=model, messages=messages, **kwargs)
content = response.choices[0].message.content or ""

elapsed = time.time() - t0
@@ -587,10 +615,14 @@ async def _compile_concepts(
# --- Step 2: Get concepts plan (A cached) ---
concept_briefs = _read_concept_briefs(wiki_dir)

# Second cache breakpoint: end of the assistant summary message. Covers
# (system + doc + summary) for the plan call and every concept call.
summary_msg = {"role": "assistant", "content": _cached_text(summary)}

plan_raw = _llm_call(model, [
system_msg,
doc_msg,
{"role": "assistant", "content": summary},
summary_msg,
{"role": "user", "content": _CONCEPTS_PLAN_USER.format(
concept_briefs=concept_briefs,
)},
@@ -632,7 +664,7 @@ async def _gen_create(concept: dict) -> tuple[str, str, bool, str]:
raw = await _llm_call_async(model, [
system_msg,
doc_msg,
{"role": "assistant", "content": summary},
summary_msg,
{"role": "user", "content": _CONCEPT_PAGE_USER.format(
title=title, doc_name=doc_name,
update_instruction="",
@@ -663,7 +695,7 @@ async def _gen_update(concept: dict) -> tuple[str, str, bool, str]:
raw = await _llm_call_async(model, [
system_msg,
doc_msg,
{"role": "assistant", "content": summary},
summary_msg,
{"role": "user", "content": _CONCEPT_UPDATE_USER.format(
title=title, doc_name=doc_name,
existing_content=existing_content,
@@ -741,13 +773,15 @@ async def compile_short_doc(
schema_md = get_agents_md(wiki_dir)
content = source_path.read_text(encoding="utf-8")

# Base context A: system + document
# Base context A: system + document. cache_control marker on the doc
# message creates a cache breakpoint that covers (system + doc) for
# every downstream call (summary, concepts-plan, every concept page).
system_msg = {"role": "system", "content": _SYSTEM_TEMPLATE.format(
schema_md=schema_md, language=language,
)}
doc_msg = {"role": "user", "content": _SUMMARY_USER.format(
doc_msg = {"role": "user", "content": _cached_text(_SUMMARY_USER.format(
doc_name=doc_name, content=content,
)}
))}

# --- Step 1: Generate summary ---
summary_raw = _llm_call(model, [system_msg, doc_msg], "summary")
@@ -792,13 +826,14 @@ async def compile_long_doc(
schema_md = get_agents_md(wiki_dir)
summary_content = summary_path.read_text(encoding="utf-8")

# Base context A
# Base context A. cache_control marker on the doc message creates a
# cache breakpoint covering (system + doc) for every concept call.
system_msg = {"role": "system", "content": _SYSTEM_TEMPLATE.format(
schema_md=schema_md, language=language,
)}
doc_msg = {"role": "user", "content": _LONG_DOC_SUMMARY_USER.format(
doc_msg = {"role": "user", "content": _cached_text(_LONG_DOC_SUMMARY_USER.format(
doc_name=doc_name, doc_id=doc_id, content=summary_content,
)}
))}

# --- Step 1: Generate overview ---
overview = _llm_call(model, [system_msg, doc_msg], "overview")
125 changes: 125 additions & 0 deletions tests/test_compiler.py
@@ -651,6 +651,131 @@ async def test_handles_bad_json(self, tmp_path):
assert (wiki / "summaries" / "doc.md").exists()


class TestCacheControl:
"""Verify cache_control breakpoints are emitted on the right messages
so Anthropic prompt caching can hit on every reuse of the base context.
"""

@staticmethod
def _has_cache_breakpoint(message: dict) -> bool:
content = message.get("content")
if not isinstance(content, list):
return False
return any(
isinstance(b, dict) and b.get("cache_control", {}).get("type") == "ephemeral"
for b in content
)

@pytest.mark.asyncio
async def test_short_doc_marks_doc_and_summary(self, tmp_path):
wiki = tmp_path / "wiki"
(wiki / "sources").mkdir(parents=True)
(wiki / "summaries").mkdir(parents=True)
(wiki / "concepts").mkdir(parents=True)
(wiki / "index.md").write_text(
"# Index\n\n## Documents\n\n## Concepts\n", encoding="utf-8",
)
src = wiki / "sources" / "doc.md"
src.write_text("Body text about caching.", encoding="utf-8")
(tmp_path / ".openkb").mkdir()

summary_response = json.dumps({"brief": "B", "content": "summary body"})
plan_response = json.dumps({
"create": [{"name": "topic", "title": "Topic"}],
"update": [], "related": [],
})
concept_response = json.dumps({"brief": "C", "content": "page body"})

captured_sync_calls: list[list[dict]] = []
captured_async_calls: list[list[dict]] = []

sync_responses = [summary_response, plan_response]

def sync_side_effect(*args, **kwargs):
captured_sync_calls.append(kwargs["messages"])
idx = min(len(captured_sync_calls) - 1, len(sync_responses) - 1)
mock_resp = MagicMock()
mock_resp.choices = [MagicMock()]
mock_resp.choices[0].message.content = sync_responses[idx]
mock_resp.usage = MagicMock(prompt_tokens=1, completion_tokens=1)
mock_resp.usage.prompt_tokens_details = None
return mock_resp

async def async_side_effect(*args, **kwargs):
captured_async_calls.append(kwargs["messages"])
mock_resp = MagicMock()
mock_resp.choices = [MagicMock()]
mock_resp.choices[0].message.content = concept_response
mock_resp.usage = MagicMock(prompt_tokens=1, completion_tokens=1)
mock_resp.usage.prompt_tokens_details = None
return mock_resp

with patch("openkb.agent.compiler.litellm") as mock_litellm:
mock_litellm.completion = MagicMock(side_effect=sync_side_effect)
mock_litellm.acompletion = AsyncMock(side_effect=async_side_effect)
await compile_short_doc("doc", src, tmp_path, "anthropic/claude-sonnet-4-5")

# Step 1 (summary): doc_msg carries the breakpoint.
summary_call = captured_sync_calls[0]
assert summary_call[0]["role"] == "system"
assert summary_call[1]["role"] == "user"
assert self._has_cache_breakpoint(summary_call[1]), (
"doc_msg in summary call must carry an ephemeral cache_control marker"
)

# Step 2 (plan): doc_msg AND assistant summary both carry breakpoints.
plan_call = captured_sync_calls[1]
assert self._has_cache_breakpoint(plan_call[1])
assert plan_call[2]["role"] == "assistant"
assert self._has_cache_breakpoint(plan_call[2]), (
"assistant summary in plan call must carry a cache_control marker"
)

# Step 3 (concept generation): same two breakpoints reused.
assert captured_async_calls, "expected at least one async concept call"
concept_call = captured_async_calls[0]
assert self._has_cache_breakpoint(concept_call[1])
assert self._has_cache_breakpoint(concept_call[2])

@pytest.mark.asyncio
async def test_long_doc_marks_doc_message(self, tmp_path):
wiki = tmp_path / "wiki"
(wiki / "summaries").mkdir(parents=True)
(wiki / "concepts").mkdir(parents=True)
(wiki / "index.md").write_text(
"# Index\n\n## Documents\n\n## Concepts\n", encoding="utf-8",
)
sp = wiki / "summaries" / "big.md"
sp.write_text("PageIndex tree summary.", encoding="utf-8")
(tmp_path / ".openkb").mkdir()

captured: list[list[dict]] = []
plan_response = json.dumps({"create": [], "update": [], "related": []})

def sync_side_effect(*args, **kwargs):
captured.append(kwargs["messages"])
mock_resp = MagicMock()
mock_resp.choices = [MagicMock()]
# First call: overview (plain text); second: plan (JSON).
mock_resp.choices[0].message.content = (
"Overview text" if len(captured) == 1 else plan_response
)
mock_resp.usage = MagicMock(prompt_tokens=1, completion_tokens=1)
mock_resp.usage.prompt_tokens_details = None
return mock_resp

with patch("openkb.agent.compiler.litellm") as mock_litellm:
mock_litellm.completion = MagicMock(side_effect=sync_side_effect)
mock_litellm.acompletion = AsyncMock()
await compile_long_doc(
"big", sp, "doc-id-1", tmp_path, "anthropic/claude-sonnet-4-5",
)

overview_call = captured[0]
assert overview_call[1]["role"] == "user"
assert self._has_cache_breakpoint(overview_call[1])


class TestCompileLongDoc:
@pytest.mark.asyncio
async def test_full_pipeline(self, tmp_path):