diff --git a/openkb/agent/compiler.py b/openkb/agent/compiler.py index d202fc4e..bf15a1c3 100644 --- a/openkb/agent/compiler.py +++ b/openkb/agent/compiler.py @@ -6,6 +6,13 @@ Step 3: A + summary → concepts plan (create/update/related). Step 4: Concurrent LLM calls (A cached) → generate new + rewrite updated concepts. Step 5: Code adds cross-ref links to related concepts, updates index. + +Anthropic prompt caching is enabled via ``cache_control`` markers at two +breakpoints: end of the document message (caches system + doc across all +N+M+2 calls) and end of the assistant summary message (caches the additional +summary prefix across N+M concept-generation calls). Providers that do not +support cache_control receive a normalized list-of-blocks content payload, +which LiteLLM passes through cleanly. """ from __future__ import annotations @@ -131,6 +138,17 @@ # LLM helpers # --------------------------------------------------------------------------- +def _cached_text(text: str) -> list[dict]: + """Wrap a text payload into a content-block list with an Anthropic + ephemeral cache_control marker. + + LiteLLM passes the marker through to Anthropic (and OpenRouter → + Anthropic). For providers that ignore cache_control, the list-of-blocks + payload remains a valid OpenAI-compatible content shape. + """ + return [{"type": "text", "text": text, "cache_control": {"type": "ephemeral"}}] + + class _Spinner: """Animated dots spinner that runs in a background thread.""" @@ -168,15 +186,23 @@ def _format_usage(elapsed: float, usage) -> str: def _fmt_messages(messages: list[dict], max_content: int = 200) -> str: - """Format messages for debug output, truncating long content.""" + """Format messages for debug output, truncating long content. + + Accepts both plain-string content and the list-of-blocks shape used by + cache_control-tagged messages (joins all text blocks for preview). 
+ """ parts = [] for msg in messages: role = msg["role"] - content = msg["content"] - if len(content) > max_content: - preview = content[:max_content] + f"... ({len(content)} chars)" + raw = msg["content"] + if isinstance(raw, list): + text = "".join(b.get("text", "") for b in raw if isinstance(b, dict)) else: - preview = content + text = raw + if len(text) > max_content: + preview = text[:max_content] + f"... ({len(text)} chars)" + else: + preview = text parts.append(f" [{role}] {preview}") return "\n".join(parts) @@ -199,13 +225,15 @@ def _llm_call(model: str, messages: list[dict], step_name: str, **kwargs) -> str return content.strip() -async def _llm_call_async(model: str, messages: list[dict], step_name: str) -> str: +async def _llm_call_async(model: str, messages: list[dict], step_name: str, **kwargs) -> str: """Async LLM call with timing output and debug logging.""" logger.debug("LLM request [%s]:\n%s", step_name, _fmt_messages(messages)) + if kwargs: + logger.debug("LLM kwargs [%s]: %s", step_name, kwargs) t0 = time.time() - response = await litellm.acompletion(model=model, messages=messages) + response = await litellm.acompletion(model=model, messages=messages, **kwargs) content = response.choices[0].message.content or "" elapsed = time.time() - t0 @@ -587,10 +615,14 @@ async def _compile_concepts( # --- Step 2: Get concepts plan (A cached) --- concept_briefs = _read_concept_briefs(wiki_dir) + # Second cache breakpoint: end of the assistant summary message. Covers + # (system + doc + summary) for the plan call and every concept call. 
+ summary_msg = {"role": "assistant", "content": _cached_text(summary)} + plan_raw = _llm_call(model, [ system_msg, doc_msg, - {"role": "assistant", "content": summary}, + summary_msg, {"role": "user", "content": _CONCEPTS_PLAN_USER.format( concept_briefs=concept_briefs, )}, @@ -632,7 +664,7 @@ async def _gen_create(concept: dict) -> tuple[str, str, bool, str]: raw = await _llm_call_async(model, [ system_msg, doc_msg, - {"role": "assistant", "content": summary}, + summary_msg, {"role": "user", "content": _CONCEPT_PAGE_USER.format( title=title, doc_name=doc_name, update_instruction="", @@ -663,7 +695,7 @@ async def _gen_update(concept: dict) -> tuple[str, str, bool, str]: raw = await _llm_call_async(model, [ system_msg, doc_msg, - {"role": "assistant", "content": summary}, + summary_msg, {"role": "user", "content": _CONCEPT_UPDATE_USER.format( title=title, doc_name=doc_name, existing_content=existing_content, @@ -741,13 +773,15 @@ async def compile_short_doc( schema_md = get_agents_md(wiki_dir) content = source_path.read_text(encoding="utf-8") - # Base context A: system + document + # Base context A: system + document. cache_control marker on the doc + # message creates a cache breakpoint that covers (system + doc) for + # every downstream call (summary, concepts-plan, every concept page). system_msg = {"role": "system", "content": _SYSTEM_TEMPLATE.format( schema_md=schema_md, language=language, )} - doc_msg = {"role": "user", "content": _SUMMARY_USER.format( + doc_msg = {"role": "user", "content": _cached_text(_SUMMARY_USER.format( doc_name=doc_name, content=content, - )} + ))} # --- Step 1: Generate summary --- summary_raw = _llm_call(model, [system_msg, doc_msg], "summary") @@ -792,13 +826,14 @@ async def compile_long_doc( schema_md = get_agents_md(wiki_dir) summary_content = summary_path.read_text(encoding="utf-8") - # Base context A + # Base context A. 
class TestCacheControl:
    """Check that ephemeral ``cache_control`` markers are placed on the
    messages that are meant to serve as prompt-cache breakpoints.
    """

    @staticmethod
    def _has_cache_breakpoint(message: dict) -> bool:
        """Return True if any content block of *message* is tagged ephemeral."""
        blocks = message.get("content")
        if isinstance(blocks, list):
            for block in blocks:
                if not isinstance(block, dict):
                    continue
                if block.get("cache_control", {}).get("type") == "ephemeral":
                    return True
        return False

    @staticmethod
    def _fake_response(text):
        """Build a mock completion response whose first choice holds *text*."""
        response = MagicMock()
        response.choices = [MagicMock()]
        response.choices[0].message.content = text
        response.usage = MagicMock(prompt_tokens=1, completion_tokens=1)
        response.usage.prompt_tokens_details = None
        return response

    @pytest.mark.asyncio
    async def test_short_doc_marks_doc_and_summary(self, tmp_path):
        # Minimal wiki layout plus one short source document.
        wiki = tmp_path / "wiki"
        for sub in ("sources", "summaries", "concepts"):
            (wiki / sub).mkdir(parents=True)
        (wiki / "index.md").write_text(
            "# Index\n\n## Documents\n\n## Concepts\n", encoding="utf-8",
        )
        src = wiki / "sources" / "doc.md"
        src.write_text("Body text about caching.", encoding="utf-8")
        (tmp_path / ".openkb").mkdir()

        # Scripted replies: the sync path answers summary then plan, the
        # async path always answers with a concept page.
        scripted_sync = [
            json.dumps({"brief": "B", "content": "summary body"}),
            json.dumps({
                "create": [{"name": "topic", "title": "Topic"}],
                "update": [], "related": [],
            }),
        ]
        concept_reply = json.dumps({"brief": "C", "content": "page body"})

        sync_calls: list[list[dict]] = []
        async_calls: list[list[dict]] = []

        def fake_completion(*args, **kwargs):
            sync_calls.append(kwargs["messages"])
            # Replay in order; keep returning the last reply once exhausted.
            position = min(len(sync_calls), len(scripted_sync)) - 1
            return self._fake_response(scripted_sync[position])

        async def fake_acompletion(*args, **kwargs):
            async_calls.append(kwargs["messages"])
            return self._fake_response(concept_reply)

        with patch("openkb.agent.compiler.litellm") as mock_litellm:
            mock_litellm.completion = MagicMock(side_effect=fake_completion)
            mock_litellm.acompletion = AsyncMock(side_effect=fake_acompletion)
            await compile_short_doc("doc", src, tmp_path, "anthropic/claude-sonnet-4-5")

        # Step 1 (summary): the document message carries the breakpoint.
        first_call = sync_calls[0]
        assert first_call[0]["role"] == "system"
        assert first_call[1]["role"] == "user"
        assert self._has_cache_breakpoint(first_call[1]), (
            "doc_msg in summary call must carry an ephemeral cache_control marker"
        )

        # Step 2 (plan): document and assistant summary are both tagged.
        second_call = sync_calls[1]
        assert self._has_cache_breakpoint(second_call[1])
        assert second_call[2]["role"] == "assistant"
        assert self._has_cache_breakpoint(second_call[2]), (
            "assistant summary in plan call must carry a cache_control marker"
        )

        # Step 3 (concept generation): the same two breakpoints are reused.
        assert async_calls, "expected at least one async concept call"
        page_call = async_calls[0]
        assert self._has_cache_breakpoint(page_call[1])
        assert self._has_cache_breakpoint(page_call[2])

    @pytest.mark.asyncio
    async def test_long_doc_marks_doc_message(self, tmp_path):
        # Long-doc flow starts from an existing summary file, not a source.
        wiki = tmp_path / "wiki"
        for sub in ("summaries", "concepts"):
            (wiki / sub).mkdir(parents=True)
        (wiki / "index.md").write_text(
            "# Index\n\n## Documents\n\n## Concepts\n", encoding="utf-8",
        )
        sp = wiki / "summaries" / "big.md"
        sp.write_text("PageIndex tree summary.", encoding="utf-8")
        (tmp_path / ".openkb").mkdir()

        seen: list[list[dict]] = []
        empty_plan = json.dumps({"create": [], "update": [], "related": []})

        def fake_completion(*args, **kwargs):
            seen.append(kwargs["messages"])
            # First call: overview (plain text); every later call: plan (JSON).
            if len(seen) == 1:
                return self._fake_response("Overview text")
            return self._fake_response(empty_plan)

        with patch("openkb.agent.compiler.litellm") as mock_litellm:
            mock_litellm.completion = MagicMock(side_effect=fake_completion)
            mock_litellm.acompletion = AsyncMock()
            await compile_long_doc(
                "big", sp, "doc-id-1", tmp_path, "anthropic/claude-sonnet-4-5",
            )

        first_call = seen[0]
        assert first_call[1]["role"] == "user"
        assert self._has_cache_breakpoint(first_call[1])