-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsession_manager.py
More file actions
85 lines (75 loc) · 3.07 KB
/
Copy pathsession_manager.py
File metadata and controls
85 lines (75 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""路由会话管理器 — 封装兼容性会话的创建、上下文应用与持久化."""
from __future__ import annotations
from typing import Any
from ..compat.canonical import (
CompatibilityTrace,
)
from ..compat.session_store import CompatSessionRecord, CompatSessionStore
from .tier import VendorTier
class RouteSessionManager:
    """Manage the compatibility-session lifecycle for a single routed request.

    Wraps an optional ``CompatSessionStore``. When no store is configured,
    record lookup returns ``(None, False)`` and persistence is a no-op.
    """

    # Maps a tier name to the provider protocol identifier recorded in the
    # compatibility trace. Tier names missing here are reported as "unknown".
    # Defined once at class level so it is not rebuilt on every request.
    _PROVIDER_PROTOCOLS = {
        "copilot": "openai_chat_completions",
        "antigravity": "gemini_generate_content",
        "zhipu": "anthropic_messages",
        "minimax": "anthropic_messages",
        "kimi": "anthropic_messages",
        "doubao": "anthropic_messages",
        "xiaomi": "anthropic_messages",
        "alibaba": "anthropic_messages",
        "anthropic": "anthropic_messages",
    }

    def __init__(self, compat_session_store: CompatSessionStore | None = None) -> None:
        """Remember the session store to use; ``None`` disables session handling."""
        self._store = compat_session_store

    async def get_or_create_record(
        self, session_key: str, trace_id: str
    ) -> tuple[CompatSessionRecord | None, bool]:
        """Fetch the compatibility session record for a key, or build a new one.

        Args:
            session_key: Key used to look the record up in the store.
            trace_id: Trace id assigned to the record if one has to be created.

        Returns:
            ``(record, is_new)``. ``is_new`` is True only when the record was
            created by this call. Without a configured store the result is
            ``(None, False)``.
        """
        if self._store is None:
            return None, False
        record = await self._store.get(session_key)
        if record is not None:
            return record, False
        # The new record is only constructed here; it is written to the store
        # by persist_session().
        return CompatSessionRecord(session_key=session_key, trace_id=trace_id), True

    def apply_compat_context(
        self,
        *,
        tier: VendorTier,
        canonical_request: Any,
        decision: Any,
        session_record: CompatSessionRecord | None,
    ) -> None:
        """Build a compatibility trace and hand it to the tier's vendor.

        Args:
            tier: Selected tier; ``tier.name`` picks the provider protocol and
                ``tier.vendor`` receives the context.
            canonical_request: Supplies ``trace_id`` and ``session_key``.
            decision: Supplies ``status.value``, ``simulation_actions`` and
                ``unsupported_semantics``.
            session_record: Session record forwarded to the vendor, or None.
        """
        provider_protocol = self._PROVIDER_PROTOCOLS.get(tier.name, "unknown")
        compat_trace = CompatibilityTrace(
            trace_id=canonical_request.trace_id,
            vendor=tier.name,
            session_key=canonical_request.session_key,
            provider_protocol=provider_protocol,
            compat_mode=decision.status.value,
            # Copy so the trace does not share list objects with the decision.
            simulation_actions=list(decision.simulation_actions),
            unsupported_semantics=list(decision.unsupported_semantics),
            session_state_hits=1 if session_record else 0,
            request_adaptations=[],
        )
        tier.vendor.set_compat_context(
            trace=compat_trace, session_record=session_record
        )

    async def persist_session(
        self,
        trace: CompatibilityTrace | None,
        session_record: CompatSessionRecord | None,
    ) -> None:
        """Record the trace outcome on the session record and upsert it.

        Does nothing when there is no store, no trace, or no session record.
        Otherwise the entry for ``trace.vendor`` in the record's provider state
        is replaced, the record's ``trace_id`` is updated, and the record is
        written through the store's ``upsert``.

        Args:
            trace: Compatibility trace of the finished request, or None.
            session_record: Record to update and persist, or None.
        """
        if self._store is None or trace is None or session_record is None:
            return
        # Work on a shallow copy so the previous provider_state mapping is not
        # mutated in place.
        provider_states = dict(session_record.provider_state)
        provider_states[trace.vendor] = {
            "compat_mode": trace.compat_mode,
            # Snapshot the lists: storing the trace's own list objects would let
            # later mutations of the trace change the persisted state.
            "simulation_actions": list(trace.simulation_actions),
            "unsupported_semantics": list(trace.unsupported_semantics),
            "trace_id": trace.trace_id,
        }
        session_record.trace_id = trace.trace_id
        session_record.provider_state = provider_states
        await self._store.upsert(session_record)