1717from common .core .config import settings
1818from common .core .deps import SessionDep , CurrentUser , Trans
1919from common .utils .embedding_threads import run_save_table_embeddings , run_save_ds_embeddings
20- from common .utils .utils import SQLBotLogUtil , deepcopy_ignore_extra
20+ from common .utils .utils import SQLBotLogUtil , deepcopy_ignore_extra , equals_ignore_case
2121from common .core .sqlbot_cache import cache , clear_cache
2222from .table import get_tables_by_ds_id
2323from ..crud .field import delete_field_by_ds_id , update_field
@@ -357,12 +357,16 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
357357 { where }
358358 LIMIT 100"""
359359 elif ds .type == "dm" :
360- sql = f"""SELECT "{ '", "' .join (fields )} " FROM "{ conf .dbSchema } "."{ data .table .table_name } "
361- { where }
360+ sql = f"""SELECT "{ '", "' .join (fields )} " FROM "{ conf .dbSchema } "."{ data .table .table_name } "
361+ { where }
362362 LIMIT 100"""
363363 elif ds .type == "es" :
364- sql = f"""SELECT "{ '", "' .join (fields )} " FROM "{ data .table .table_name } "
365- { where }
364+ sql = f"""SELECT "{ '", "' .join (fields )} " FROM "{ data .table .table_name } "
365+ { where }
366+ LIMIT 100"""
367+ elif ds .type == "sqlite" :
368+ sql = f"""SELECT "{ '", "' .join (fields )} " FROM "{ data .table .table_name } "
369+ { where }
366370 LIMIT 100"""
367371 return exec_sql (ds , sql , True )
368372
@@ -430,6 +434,79 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
430434 return _list
431435
432436
437+ def get_table_sample_data (ds : CoreDatasource , table_name : str , fields : list ) -> str :
438+ """Get 3 sample rows from a table in JSON format to help AI understand the data"""
439+ if not fields :
440+ return ""
441+
442+ db = DB .get_db (ds .type )
443+ # Get prefix/suffix for identifier quoting
444+ prefix = db .prefix if hasattr (db , 'prefix' ) else '"'
445+ suffix = db .suffix if hasattr (db , 'suffix' ) else '"'
446+
447+ # Build field list with proper quoting
448+ field_names = []
449+ for field in fields [:10 ]: # Limit to first 10 fields to avoid too wide results
450+ field_name = f"{ prefix } { field .field_name } { suffix } "
451+ field_names .append (field_name )
452+
453+ # Build LIMIT query based on database type
454+ if equals_ignore_case (ds .type , "sqlServer" ):
455+ query = f"SELECT TOP 3 { ',' .join (field_names )} FROM { prefix } { table_name } { suffix } "
456+ elif equals_ignore_case (ds .type , "ck" ):
457+ query = f"SELECT { ',' .join (field_names )} FROM { table_name } LIMIT 3"
458+ elif equals_ignore_case (ds .type , "hive" ):
459+ query = f"SELECT { ',' .join (field_names )} FROM { table_name } LIMIT 3"
460+ elif equals_ignore_case (ds .type , "oracle" ):
461+ query = f"SELECT { ',' .join (field_names )} FROM \" { table_name } \" WHERE ROWNUM <= 3"
462+ elif equals_ignore_case (ds .type , "dm" ):
463+ query = f"SELECT { ',' .join (field_names )} FROM \" { table_name } \" WHERE ROWNUM <= 3"
464+ else :
465+ query = f"SELECT { ',' .join (field_names )} FROM { prefix } { table_name } { suffix } LIMIT 3"
466+
467+ try :
468+ result = exec_sql (ds = ds , sql = query , origin_column = True )
469+ if result and result .get ('data' ) and len (result ['data' ]) > 0 :
470+ import json
471+ # Truncate long string values for readability
472+ json_rows = []
473+ for row in result ['data' ][:3 ]:
474+ truncated_row = {}
475+ for key , value in row .items ():
476+ if value is None :
477+ truncated_row [key ] = None
478+ elif isinstance (value , str ):
479+ # Truncate long strings
480+ if len (value ) > 100 :
481+ value = value [:100 ] + '...'
482+ truncated_row [key ] = value .replace ('\n ' , ' ' ).replace ('\r ' , ' ' )
483+ else :
484+ truncated_row [key ] = value
485+ json_rows .append (truncated_row )
486+ return json .dumps (json_rows , ensure_ascii = False , indent = 2 )
487+ except Exception :
488+ pass
489+ return ""
490+
491+
492+ def get_tables_sample_data (session : SessionDep , current_user : CurrentUser , ds : CoreDatasource ,
493+ table_list : list [str ] = None ) -> str :
494+ """Get sample data (3 rows) for all tables to help AI understand the data"""
495+ table_objs = get_table_obj_by_ds (session = session , current_user = current_user , ds = ds )
496+ if len (table_objs ) == 0 :
497+ return ""
498+
499+ sample_data_parts = []
500+ for obj in table_objs :
501+ if table_list is not None and obj .table .table_name not in table_list :
502+ continue
503+ if obj .fields :
504+ sample = get_table_sample_data (ds , obj .table .table_name , obj .fields )
505+ if sample :
506+ sample_data_parts .append (f"# Table: { obj .table .table_name } \n { sample } " )
507+ return "\n " .join (sample_data_parts )
508+
509+
433510def get_table_schema (session : SessionDep , current_user : CurrentUser , ds : CoreDatasource , question : str ,
434511 embedding : bool = True , table_list : list [str ] = None ) -> str :
435512 schema_str = ""
@@ -446,7 +523,8 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
446523 continue
447524
448525 schema_table = ''
449- schema_table += f"# Table: { db_name } .{ obj .table .table_name } " if ds .type != "mysql" and ds .type != "es" else f"# Table: { obj .table .table_name } "
526+ no_schema_types = ["mysql" , "es" , "sqlite" , "hive" , "doris" , "starrocks" ]
527+ schema_table += f"# Table: { db_name } .{ obj .table .table_name } " if ds .type not in no_schema_types and db_name else f"# Table: { obj .table .table_name } "
450528 table_comment = ''
451529 if obj .table .custom_comment :
452530 table_comment = obj .table .custom_comment .strip ()
0 commit comments