Skip to content

Commit c6f41d7

Browse files
greatyaocursoragent
andcommitted
feat(chat): full-page stream, scatter chart, layout and backend updates
Run question streaming from the chat index on the full page so hot cards and first turns no longer depend on ChartAnswer ref timing. Extract shared runQuestionStream and wire abort with global stop. Add scatter chart support in the frontend chart registry, G2 SSR, and i18n. Extend layout DSL and menus, chat list UX, templates and Oracle examples, datasource CRUD/API, and chat utilities including popular question clustering. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 99482af commit c6f41d7

29 files changed

Lines changed: 1420 additions & 275 deletions

File tree

backend/apps/chat/api/chat.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55

66
import orjson
77
import pandas as pd
8-
from fastapi import APIRouter, HTTPException, Path
8+
from fastapi import APIRouter, HTTPException, Path, Query
99
from fastapi.responses import StreamingResponse
1010
from sqlalchemy import and_, select
1111
from starlette.responses import JSONResponse
1212

1313
from apps.chat.curd.chat import delete_chat_with_user, get_chart_data_with_user, get_chat_predict_data_with_user, \
1414
list_chats, get_chat_with_records, create_chat, rename_chat, \
1515
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
16-
format_json_data, format_json_list_data, get_chart_config, list_recent_questions, get_chat as get_chat_exec, \
16+
format_json_data, format_json_list_data, get_chart_config, list_recent_questions, list_popular_questions, \
17+
get_chat as get_chat_exec, \
1718
rename_chat_with_user, get_chat_log_history, get_chart_data_with_user_live
1819
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \
1920
ChatInfo, Chat, ChatFinishStep
@@ -34,6 +35,17 @@ async def chats(session: SessionDep, current_user: CurrentUser):
3435
return list_chats(session, current_user)
3536

3637

38+
@router.get("/popular_questions", summary=f"{PLACEHOLDER_PREFIX}popular_questions_workspace")
39+
async def popular_questions(
40+
session: SessionDep, current_user: CurrentUser, limit: int = Query(8, ge=1, le=50)
41+
):
42+
"""工作空间内提问频次统计(排除首条占位记录)。"""
43+
def inner():
44+
return list_popular_questions(session=session, current_user=current_user, limit=limit)
45+
46+
return await asyncio.to_thread(inner)
47+
48+
3749
@router.get("/{chart_id}", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}get_chat")
3850
async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant,
3951
trans: Trans):

backend/apps/chat/curd/chat.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from common.utils.data_format import DataFormat
2121
from common.utils.utils import extract_nested_json, SQLBotLogUtil
2222

23+
from apps.chat.utils.popular_questions_cluster import cluster_questions_for_datasource
24+
2325

2426
def get_chat_record_by_id(session: SessionDep, record_id: int):
2527
record: ChatRecord | None = None
@@ -66,6 +68,58 @@ def list_recent_questions(session: SessionDep, current_user: CurrentUser, dataso
6668
return [record[0] for record in chat_records] if chat_records else []
6769

6870

71+
def list_popular_questions(session: SessionDep, current_user: CurrentUser, limit: int = 8) -> List[Dict[str, Any]]:
72+
"""按数据源 + 语义合并统计热门问题(同一数据源内相近问句合并,非纯字面 group_by)。"""
73+
oid = current_user.oid if current_user.oid is not None else 1
74+
limit = min(max(limit, 1), 50)
75+
cnt = func.count(ChatRecord.id).label('cnt')
76+
rows = (
77+
session.query(Chat.datasource, ChatRecord.question, cnt)
78+
.join(Chat, ChatRecord.chat_id == Chat.id)
79+
.filter(
80+
Chat.oid == oid,
81+
Chat.create_by == current_user.id,
82+
Chat.datasource.isnot(None),
83+
ChatRecord.question.isnot(None),
84+
ChatRecord.question != '',
85+
ChatRecord.first_chat.isnot(True),
86+
)
87+
.group_by(Chat.datasource, ChatRecord.question)
88+
.order_by(desc(cnt))
89+
.limit(400)
90+
.all()
91+
)
92+
by_ds: Dict[Any, List[tuple]] = {}
93+
for ds_id, question, c in rows:
94+
by_ds.setdefault(ds_id, []).append((question, int(c)))
95+
96+
ds_ids = [k for k in by_ds.keys() if k is not None]
97+
id_to_name: Dict[Any, str] = {}
98+
if ds_ids:
99+
ds_rows = session.query(CoreDatasource.id, CoreDatasource.name).filter(
100+
CoreDatasource.id.in_(ds_ids),
101+
CoreDatasource.oid == oid,
102+
).all()
103+
id_to_name = {r[0]: r[1] for r in ds_rows}
104+
105+
flat: List[Dict[str, Any]] = []
106+
for ds_id, weighted in by_ds.items():
107+
if ds_id is None:
108+
continue
109+
for rep_q, total in cluster_questions_for_datasource(weighted):
110+
flat.append(
111+
{
112+
'datasource_id': int(ds_id),
113+
'datasource_name': id_to_name.get(ds_id) or '',
114+
'question': rep_q,
115+
'count': total,
116+
}
117+
)
118+
119+
flat.sort(key=lambda x: (-x['count'], x.get('datasource_name') or ''))
120+
return flat[:limit]
121+
122+
69123
def rename_chat_with_user(session: SessionDep, current_user: CurrentUser, rename_object: RenameChat) -> str:
70124
chat = session.get(Chat, rename_object.id)
71125
if not chat:
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""
2+
热门问题:按数据源聚合,并在同一数据源内做语义相近合并(非纯字面 group_by)。
3+
4+
1. 意图桶:库表/数据概览类中文问法合并为同一主题(见 META_OVERVIEW_PATTERN)。
5+
2. 向量聚类:对其余问句用本地中文 embedding 做余弦相似度合并(可选,失败则回退)。
6+
3. 回退:归一化 + difflib 合并相近字面。
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import re
12+
from difflib import SequenceMatcher
13+
from typing import Any, Dict, List, Tuple
14+
15+
import numpy as np
16+
17+
# 表/数据量/有哪些数据 等「元信息」类问题归为一类(用户示例)
18+
META_OVERVIEW_PATTERN = re.compile(
19+
r"(几张表|哪些表|多少张表|有多少表|表.*数据量|数据量.*表|分别.*数据量|数据量.*多大|"
20+
r"哪些数据|有什么数据|有哪些数据|什么数据|库表|schema|多少条数据|统计.*表|表的.*数量)",
21+
re.IGNORECASE,
22+
)
23+
24+
25+
def normalize_question(s: str) -> str:
26+
if not s:
27+
return ""
28+
t = s.strip()
29+
t = re.sub(r"[\s\u3000]+", "", t)
30+
t = re.sub(r"[。..!?!?;;,、]+$", "", t)
31+
return t
32+
33+
34+
def _split_meta_overview(
35+
weighted: List[Tuple[str, int]],
36+
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
37+
meta: List[Tuple[str, int]] = []
38+
rest: List[Tuple[str, int]] = []
39+
for q, c in weighted:
40+
if META_OVERVIEW_PATTERN.search(q):
41+
meta.append((q, c))
42+
else:
43+
rest.append((q, c))
44+
out: List[Tuple[str, int]] = []
45+
if meta:
46+
rep = max(meta, key=lambda x: x[1])[0]
47+
total = sum(c for _, c in meta)
48+
out.append((rep, total))
49+
return out, rest
50+
51+
52+
def _merge_difflib(weighted: List[Tuple[str, int]], threshold: float = 0.78) -> List[Tuple[str, int]]:
53+
if not weighted:
54+
return []
55+
items = sorted(weighted, key=lambda x: -x[1])
56+
clusters: List[Dict[str, Any]] = []
57+
for q, c in items:
58+
nq = normalize_question(q)
59+
best_i = -1
60+
best_r = 0.0
61+
for i, cl in enumerate(clusters):
62+
r = SequenceMatcher(None, nq, cl["norm"]).ratio()
63+
if r >= threshold and r > best_r:
64+
best_r = r
65+
best_i = i
66+
if best_i >= 0:
67+
clusters[best_i]["count"] += c
68+
if c > clusters[best_i].get("max_w", 0):
69+
clusters[best_i]["rep"] = q
70+
clusters[best_i]["max_w"] = c
71+
else:
72+
clusters.append({"rep": q, "count": c, "norm": nq, "max_w": c})
73+
return [(c["rep"], int(c["count"])) for c in clusters]
74+
75+
76+
def _merge_embedding(weighted: List[Tuple[str, int]], threshold: float = 0.76) -> List[Tuple[str, int]]:
77+
if len(weighted) <= 1:
78+
return weighted
79+
try:
80+
from apps.ai_model.embedding import EmbeddingModelCache
81+
82+
texts = [w[0] for w in weighted]
83+
model = EmbeddingModelCache.get_model()
84+
embs = model.embed_documents(texts)
85+
arr = np.array(embs, dtype=np.float32)
86+
norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-9
87+
arr = arr / norms
88+
n = len(weighted)
89+
parent = list(range(n))
90+
91+
def find(a: int) -> int:
92+
while parent[a] != a:
93+
parent[a] = parent[parent[a]]
94+
a = parent[a]
95+
return a
96+
97+
def union(a: int, b: int) -> None:
98+
ra, rb = find(a), find(b)
99+
if ra != rb:
100+
parent[rb] = ra
101+
102+
sim = arr @ arr.T
103+
for i in range(n):
104+
for j in range(i + 1, n):
105+
if float(sim[i, j]) >= threshold:
106+
union(i, j)
107+
groups: Dict[int, List[int]] = {}
108+
for i in range(n):
109+
r = find(i)
110+
groups.setdefault(r, []).append(i)
111+
out: List[Tuple[str, int]] = []
112+
for idxs in groups.values():
113+
total = sum(weighted[i][1] for i in idxs)
114+
rep_q = max((weighted[i] for i in idxs), key=lambda x: x[1])[0]
115+
out.append((rep_q, int(total)))
116+
return out
117+
except Exception:
118+
return _merge_difflib(weighted, threshold=0.78)
119+
120+
121+
def cluster_questions_for_datasource(weighted: List[Tuple[str, int]]) -> List[Tuple[str, int]]:
122+
"""同一数据源下多组 (原文, 次数) -> 合并后 (代表问句, 总次数)。"""
123+
if not weighted:
124+
return []
125+
meta_merged, rest = _split_meta_overview(weighted)
126+
if not rest:
127+
return meta_merged
128+
embedded_or_fb = _merge_embedding(rest)
129+
return meta_merged + embedded_or_fb

backend/apps/datasource/api/datasource.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def inner():
231231
try:
232232
return preview(session, current_user, id, data)
233233
except Exception as e:
234+
SQLBotLogUtil.error(f"Preview failed: {e}, try another way")
234235
ds = session.query(CoreDatasource).filter(CoreDatasource.id == id).first()
235236
# check ds status
236237
status = check_status(session, trans, ds, True)

backend/apps/datasource/crud/datasource.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,8 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
329329

330330
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
331331
sql: str = ""
332-
if ds.type == "mysql" or ds.type == "doris" or ds.type == "starrocks":
333-
sql = f"""SELECT `{"`, `".join(fields)}` FROM `{data.table.table_name}`
332+
if ds.type == "mysql" or ds.type == "doris" or ds.type == "starrocks" or ds.type == "hive":
333+
sql = f"""SELECT `{"`, `".join(fields)}` FROM `{conf.database}`.`{data.table.table_name}`
334334
{where}
335335
LIMIT 100"""
336336
elif ds.type == "sqlServer":

backend/templates/sql_examples/Oracle.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ template:
88
<step>5. 应用其他规则(引号、别名、格式化等)</step>
99
<step>6. <strong>最终验证:GROUP BY查询的ROWNUM位置是否正确?</strong></step>
1010
<step>7. <strong>强制检查:验证SQL语法是否符合<db-engine>规范</strong></step>
11-
<step>8. 确定图表类型(根据规则选择table/column/bar/line/pie)</step>
11+
<step>8. 确定图表类型(根据规则选择table/column/bar/line/pie/scatter)</step>
1212
<step>9. 确定对话标题</step>
1313
<step>10. 生成JSON结果</step>
1414
<step>11. <strong>强制检查:JSON格式是否正确</strong></step>

0 commit comments

Comments
 (0)