Skip to content

Commit 6fcf8c9

Browse files
yymclaude
authored and committed
feat: Add SQLite and Hive database support with sample data for AI
- Add SQLite support to data sources (file-based connection)
- Add Hive support to data sources
- Add SSL toggle option for MySQL and Doris databases
- Add sample data (3 rows, JSON format) to help the AI better understand table schemas
- Fix SQLite schema table name issue (empty schema prefix)
- Add SQLite.yaml template for SQL generation
- Add Hive to haveSchema list in frontend

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3877312 commit 6fcf8c9

10 files changed

Lines changed: 343 additions & 19 deletions

File tree

backend/apps/chat/models/chat_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ class AiModelQuestion(BaseModel):
229229
custom_prompt: str = ""
230230
error_msg: str = ""
231231
regenerate_record_id: Optional[int] = None
232+
sample_data: str = ""
232233

233234
def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
234235
templates: dict[str, str] = {}
@@ -256,7 +257,7 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
256257
example_answer_1=_example_answer_1,
257258
example_answer_2=_example_answer_2,
258259
example_answer_3=_example_answer_3)
259-
templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema)
260+
templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema, sample_data=self.sample_data)
260261

261262
if self.terminologies:
262263
templates['terminologies'] = _base_template['generate_terminologies_info'].format(

backend/apps/chat/task/llm.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
3737
ChatFinishStep, AxisObj, SystemPromptMessage, HumanPromptMessage, AIPromptMessage
3838
from apps.data_training.curd.data_training import get_training_template
39-
from apps.datasource.crud.datasource import get_table_schema
39+
from apps.datasource.crud.datasource import get_table_schema, get_tables_sample_data
4040
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
4141
from apps.datasource.embedding.ds_embedding import get_ds_embedding
4242
from apps.datasource.models.datasource import CoreDatasource
@@ -384,6 +384,13 @@ def choose_table_schema(self, _session: Session):
384384
ds=self.ds,
385385
question=self.chat_question.question)
386386

387+
# Get sample data for all tables
388+
if not self.out_ds_instance:
389+
self.chat_question.sample_data = get_tables_sample_data(
390+
session=_session,
391+
current_user=self.current_user,
392+
ds=self.ds)
393+
387394
self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session,
388395
log=self.current_logs[OperationEnum.CHOOSE_TABLE],
389396
full_message=self.chat_question.db_schema)
@@ -505,6 +512,13 @@ def generate_recommend_questions_task(self, _session: Session):
505512
question=self.chat_question.question,
506513
embedding=False)
507514

515+
# Get sample data for all tables
516+
if not self.out_ds_instance:
517+
self.chat_question.sample_data = get_tables_sample_data(
518+
session=_session,
519+
current_user=self.current_user,
520+
ds=self.ds)
521+
508522
guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
509523
guess_msg.append(SystemPromptMessage(content=self.chat_question.guess_sys_question(self.articles_number)))
510524

backend/apps/datasource/crud/datasource.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from common.core.config import settings
1818
from common.core.deps import SessionDep, CurrentUser, Trans
1919
from common.utils.embedding_threads import run_save_table_embeddings, run_save_ds_embeddings
20-
from common.utils.utils import SQLBotLogUtil, deepcopy_ignore_extra
20+
from common.utils.utils import SQLBotLogUtil, deepcopy_ignore_extra, equals_ignore_case
2121
from common.core.sqlbot_cache import cache, clear_cache
2222
from .table import get_tables_by_ds_id
2323
from ..crud.field import delete_field_by_ds_id, update_field
@@ -357,12 +357,16 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
357357
{where}
358358
LIMIT 100"""
359359
elif ds.type == "dm":
360-
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}"
361-
{where}
360+
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}"
361+
{where}
362362
LIMIT 100"""
363363
elif ds.type == "es":
364-
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
365-
{where}
364+
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
365+
{where}
366+
LIMIT 100"""
367+
elif ds.type == "sqlite":
368+
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
369+
{where}
366370
LIMIT 100"""
367371
return exec_sql(ds, sql, True)
368372

@@ -430,6 +434,79 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
430434
return _list
431435

432436

437+
# Limits that keep the sample compact enough to embed in an LLM prompt.
_SAMPLE_ROW_LIMIT = 3
_SAMPLE_FIELD_LIMIT = 10
_SAMPLE_TEXT_LIMIT = 100


def _compact_sample_value(value):
    """Shorten long strings and flatten line breaks; pass other values through unchanged."""
    if not isinstance(value, str):
        return value
    if len(value) > _SAMPLE_TEXT_LIMIT:
        value = value[:_SAMPLE_TEXT_LIMIT] + '...'
    return value.replace('\n', ' ').replace('\r', ' ')


def _build_sample_query(ds_type: str, columns: str, table_name: str, prefix: str, suffix: str) -> str:
    """Return the dialect-specific "first N rows" SELECT statement for one table."""
    # NOTE(review): none of these statements qualify the table with a schema or
    # database name, unlike the preview queries in this module — confirm that the
    # connection's default schema is always the intended one.
    if equals_ignore_case(ds_type, "sqlServer"):
        return f"SELECT TOP {_SAMPLE_ROW_LIMIT} {columns} FROM {prefix}{table_name}{suffix}"
    if any(equals_ignore_case(ds_type, name) for name in ("ck", "hive")):
        # Table name is intentionally left unquoted for these engines (original behaviour).
        return f"SELECT {columns} FROM {table_name} LIMIT {_SAMPLE_ROW_LIMIT}"
    if any(equals_ignore_case(ds_type, name) for name in ("oracle", "dm")):
        return f"SELECT {columns} FROM \"{table_name}\" WHERE ROWNUM <= {_SAMPLE_ROW_LIMIT}"
    return f"SELECT {columns} FROM {prefix}{table_name}{suffix} LIMIT {_SAMPLE_ROW_LIMIT}"


def get_table_sample_data(ds: CoreDatasource, table_name: str, fields: list) -> str:
    """Get up to 3 sample rows from a table as a JSON string to help AI understand the data.

    Only the first 10 entries of ``fields`` are selected, string values longer than
    100 characters are truncated, and line breaks inside strings are replaced by spaces.

    Args:
        ds: Datasource the query is executed against.
        table_name: Name of the table to sample.
        fields: Field objects exposing ``field_name``; an empty or ``None`` value
            short-circuits to ``""``.

    Returns:
        An indented JSON array of row objects, or ``""`` when there are no fields,
        the query returns no rows, or sampling fails. Sampling is best-effort: a
        failure is logged and never propagated to the caller.
    """
    if not fields:
        return ""

    import json

    db = DB.get_db(ds.type)
    # Identifier quoting characters for this engine; default to ANSI double quotes.
    prefix = getattr(db, 'prefix', '"')
    suffix = getattr(db, 'suffix', '"')
    if equals_ignore_case(ds.type, "hive"):
        # HiveQL treats "x" as a string literal; identifiers are quoted with
        # backticks. With double quotes the query would return the column names
        # themselves instead of the column values.
        prefix = suffix = '`'

    columns = ','.join(
        f"{prefix}{field.field_name}{suffix}" for field in fields[:_SAMPLE_FIELD_LIMIT]
    )
    query = _build_sample_query(ds.type, columns, table_name, prefix, suffix)

    try:
        result = exec_sql(ds=ds, sql=query, origin_column=True)
        rows = result.get('data') if result else None
        if not rows:
            return ""
        compact_rows = [
            {key: _compact_sample_value(value) for key, value in row.items()}
            for row in rows[:_SAMPLE_ROW_LIMIT]
        ]
        # default=str is deliberate: values such as dates, decimals or bytes only
        # need a readable rendering for the prompt, and without it one such value
        # would make the whole table's sample disappear.
        return json.dumps(compact_rows, ensure_ascii=False, indent=2, default=str)
    except Exception as e:
        # Best-effort feature: keep the chat flow alive, but leave a trace
        # instead of silently dropping the failure.
        SQLBotLogUtil.error(f"Failed to fetch sample data for table {table_name}: {e}")
        return ""
490+
491+
492+
def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource,
                           table_list: list[str] = None) -> str:
    """Collect per-table sample rows for a datasource into one text block for the AI prompt.

    Tables are taken from ``get_table_obj_by_ds``. When ``table_list`` is given, only
    tables whose name appears in it are sampled. Tables without fields, or whose
    sample comes back empty, are left out.

    Returns:
        One ``# Table: <name>`` header followed by that table's sample per table,
        joined with newlines; ``""`` when the datasource has no table objects.
    """
    candidates = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
    if len(candidates) == 0:
        return ""

    sections = []
    for entry in candidates:
        # Honour the optional whitelist of table names.
        if table_list is not None and entry.table.table_name not in table_list:
            continue
        # Nothing to select from a table without known fields.
        if not entry.fields:
            continue
        name = entry.table.table_name
        rows_json = get_table_sample_data(ds, name, entry.fields)
        if not rows_json:
            continue
        sections.append(f"# Table: {name}\n{rows_json}")
    return "\n".join(sections)
508+
509+
433510
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
434511
embedding: bool = True, table_list: list[str] = None) -> str:
435512
schema_str = ""
@@ -446,7 +523,8 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
446523
continue
447524

448525
schema_table = ''
449-
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
526+
no_schema_types = ["mysql", "es", "sqlite", "hive", "doris", "starrocks"]
527+
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type not in no_schema_types and db_name else f"# Table: {obj.table.table_name}"
450528
table_comment = ''
451529
if obj.table.custom_comment:
452530
table_comment = obj.table.custom_comment.strip()

backend/apps/db/constant.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class DB(Enum):
2828
oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy, 'Oracle', [])
2929
pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL', [])
3030
starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks', [])
31+
sqlite = ('sqlite', 'SQLite', '"', '"', ConnectType.sqlalchemy, 'SQLite', [])
32+
hive = ('hive', 'Apache Hive', '"', '"', ConnectType.py_driver, 'Hive', [])
3133

3234
def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str,
3335
illegalParams: List[str]):

0 commit comments

Comments
 (0)