diff --git a/changelog.md b/changelog.md index de9ef78e..0272a7ad 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,7 @@ Features --------- * Respond to `-h` alone with the helpdoc. * Allow `--hostname` as an alias for `--host`. +* Suggest tables with foreign key relationships for JOIN and ON (#975) * Deprecate `$DSN` environment variable in favor of `$MYSQL_DSN`. diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py index 38b547b2..94e6429c 100644 --- a/mycli/completion_refresher.py +++ b/mycli/completion_refresher.py @@ -132,6 +132,11 @@ def refresh_tables(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_columns(table_columns_dbresult, kind="tables") +@refresher("foreign_keys") +def refresh_foreign_keys(completer: SQLCompleter, executor: SQLExecute) -> None: + completer.extend_foreign_keys(executor.foreign_keys()) + + @refresher("enum_values") def refresh_enum_values(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_enum_values(executor.enum_values()) diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index 845b4d0e..cc8f41a7 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -476,10 +476,14 @@ def suggest_based_on_last_token( or (token_v == "like" and re.match(r'^\s*create\s+table\s', full_text, re.IGNORECASE)) ): schema = (identifier and identifier.get_parent_name()) or [] + is_join = token_v.endswith("join") # Suggest tables from either the currently-selected schema or the # public schema if no schema has been specified - suggest = [{"type": "table", "schema": schema}] + table_suggestion: dict[str, Any] = {"type": "table", "schema": schema} + if is_join: + table_suggestion["join"] = True + suggest = [table_suggestion] if not schema: # Suggest schemas @@ -516,7 +520,7 @@ def suggest_based_on_last_token( # ON # Use table alias if there is one, otherwise the table name aliases = [alias or table for (schema, table, alias) in tables] - suggest = [{"type": "alias", "aliases": aliases}] + suggest = [{"type": "fk_join", "tables": tables}, {"type": "alias", "aliases": aliases}] # The lists of 'aliases' could be empty if we're trying to complete # a GRANT query. eg: GRANT SELECT, INSERT ON diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index ba897398..44e1bcb2 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -13,7 +13,7 @@ from mycli.packages.completion_engine import is_inside_quotes, suggest_type from mycli.packages.filepaths import complete_path, parse_path, suggest_path -from mycli.packages.parseutils import extract_columns_from_select, last_word +from mycli.packages.parseutils import extract_columns_from_select, extract_tables, last_word from mycli.packages.special import llm from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS @@ -1052,6 +1052,51 @@ def extend_enum_values(self, enum_data: Iterable[tuple[str, str, list[str]]]) -> table_meta = metadata[self.dbname].setdefault(relname_escaped, {}) table_meta[column_escaped] = values + def extend_foreign_keys(self, fk_data: Iterable[tuple[str, str, str, str]]) -> None: + """Extend FK metadata. + + :param fk_data: iterable of (table_name, column_name, referenced_table_name, referenced_column_name) + """ + metadata = self.dbmetadata["foreign_keys"] + schema_meta = metadata.setdefault(self.dbname, {}) + schema_meta.setdefault("tables", {}) + schema_meta.setdefault("relations", []) + for table, col, ref_table, ref_col in fk_data: + table = self.escape_name(table) + col = self.escape_name(col) + ref_table = self.escape_name(ref_table) + ref_col = self.escape_name(ref_col) + schema_meta["tables"].setdefault(table, set()).add(ref_table) + schema_meta["tables"].setdefault(ref_table, set()).add(table) + schema_meta["relations"].append((table, col, ref_table, ref_col)) + + def _fk_join_conditions(self, tables: list[tuple[str | None, str, str]]) -> list[str]: + """Return FK-based join condition strings for the tables currently in the query. + + For each FK relation where both the FK table and the referenced table appear in + *tables*, yields a string like ``alias1.col = alias2.ref_col`` (using the alias + when one exists, otherwise the table name). + """ + schema_meta = self.dbmetadata["foreign_keys"].get(self.dbname, {}) + relations = schema_meta.get("relations", []) + + # Map escaped table name -> alias (or table name when no alias). + # Skip tables from a different schema; we only have FK metadata for the current db. + alias_map: dict[str, str] = {} + for tbl_schema, tbl, alias in tables: + if tbl_schema and tbl_schema != self.dbname: + continue + escaped = self.escape_name(tbl) + alias_map[escaped] = alias or tbl + + conditions: list[str] = [] + for fk_table, fk_col, ref_table, ref_col in relations: + lhs = alias_map.get(fk_table) + rhs = alias_map.get(ref_table) + if lhs and rhs: + conditions.append(f"{lhs}.{fk_col} = {rhs}.{ref_col}") + return conditions + def extend_functions(self, func_data: list[str] | Generator[tuple[str, str]], builtin: bool = False) -> None: # if 'builtin' is set this is extending the list of builtin functions if builtin: @@ -1124,6 +1169,7 @@ def reset_completions(self) -> None: "functions": {}, "procedures": {}, "enum_values": {}, + "foreign_keys": {}, } self.all_completions = set(self.keywords + self.functions) @@ -1366,12 +1412,39 @@ def get_completions( tables = self.populate_schema_objects(suggestion["schema"], "tables", columns) else: tables = self.populate_schema_objects(suggestion["schema"], "tables") - tables_m = self.find_matches( - word_before_cursor, - tables, - text_before_cursor=document.text_before_cursor, - ) - completions.extend([(*x, rank) for x in tables_m]) + + if suggestion.get("join"): + # For JOINs, suggest FK-related tables first (lower rank = higher priority) + current_tables = extract_tables(document.text) + fk_map = self.dbmetadata["foreign_keys"].get(self.dbname, {}).get("tables", {}) + fk_related: set[str] = set() + for tbl_schema, tbl, _alias in current_tables: + # Skip cross-schema tables; FK metadata is only for the current db + if tbl_schema and tbl_schema != self.dbname: + continue + escaped = self.escape_name(tbl) + fk_related.update(fk_map.get(escaped, set())) + fk_tables = [t for t in tables if t in fk_related] + other_tables = [t for t in tables if t not in fk_related] + fk_tables_m = self.find_matches( + word_before_cursor, + fk_tables, + text_before_cursor=document.text_before_cursor, + ) + other_tables_m = self.find_matches( + word_before_cursor, + other_tables, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in fk_tables_m]) + completions.extend([(*x, rank + 1) for x in other_tables_m]) + else: + tables_m = self.find_matches( + word_before_cursor, + tables, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in tables_m]) elif suggestion["type"] == "view": views = self.populate_schema_objects(suggestion["schema"], "views") @@ -1382,6 +1455,15 @@ def get_completions( ) completions.extend([(*x, rank) for x in views_m]) + elif suggestion["type"] == "fk_join": + fk_conditions = self._fk_join_conditions(suggestion["tables"]) + fk_conditions_m = self.find_matches( + word_before_cursor, + fk_conditions, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in fk_conditions_m]) + elif suggestion["type"] == "alias": aliases = suggestion["aliases"] aliases_m = self.find_matches( diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 16b0f04d..d9fa108e 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -115,6 +115,10 @@ class SQLExecute: where table_schema = %s and data_type = 'enum' order by table_name,ordinal_position""" + foreign_keys_query = """SELECT TABLE_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME + FROM information_schema.KEY_COLUMN_USAGE + WHERE TABLE_SCHEMA = %s AND REFERENCED_TABLE_NAME IS NOT NULL""" + now_query = """SELECT NOW()""" @staticmethod @@ -440,6 +444,17 @@ def enum_values(self) -> Generator[tuple[str, str, list[str]], None, None]: if values: yield (table_name, column_name, values) + def foreign_keys(self) -> Generator[tuple[str, str, str, str], None, None]: + """Yields (table_name, column_name, referenced_table_name, referenced_column_name) tuples""" + assert isinstance(self.conn, Connection) + with self.conn.cursor() as cur: + _logger.debug("Foreign Keys Query. sql: %r", self.foreign_keys_query) + try: + cur.execute(self.foreign_keys_query, (self.dbname,)) + yield from cur + except Exception as e: + _logger.error('No foreign key completions due to %r', e) + def databases(self) -> list[str]: assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: diff --git a/test/pytests/test_completion_engine.py b/test/pytests/test_completion_engine.py index 582ea37c..e413ab5d 100644 --- a/test/pytests/test_completion_engine.py +++ b/test/pytests/test_completion_engine.py @@ -167,7 +167,6 @@ def test_select_suggests_cols_and_funcs(): "DESCRIBE ", "DESC ", "EXPLAIN ", - "SELECT * FROM foo JOIN ", ], ) def test_expression_suggests_tables_views_and_schemas(expression): @@ -179,6 +178,16 @@ def test_expression_suggests_tables_views_and_schemas(expression): ]) +def test_join_expression_suggests_tables_views_and_schemas(): + expression = "SELECT * FROM foo JOIN " + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": [], "join": True}, + {"type": "view", "schema": []}, + {"type": "database"}, + ]) + + @pytest.mark.parametrize( "expression", [ @@ -189,7 +198,6 @@ def test_expression_suggests_tables_views_and_schemas(expression): "DESCRIBE sch.", "DESC sch.", "EXPLAIN sch.", - "SELECT * FROM foo JOIN sch.", ], ) def test_expression_suggests_qualified_tables_views_and_schemas(expression): @@ -200,6 +208,15 @@ def test_expression_suggests_qualified_tables_views_and_schemas(expression): ]) +def test_join_expression_suggests_qualified_tables_views_and_schemas(): + expression = "SELECT * FROM foo JOIN sch." + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": "sch", "join": True}, + {"type": "view", "schema": "sch"}, + ]) + + def test_truncate_suggests_tables_and_schemas(): suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") assert sorted_dicts(suggestions) == sorted_dicts([ @@ -395,7 +412,7 @@ def test_join_suggests_tables_and_schemas(tbl_alias, join_type): suggestion = suggest_type(text, text) assert sorted_dicts(suggestion) == sorted_dicts([ {"type": "database"}, - {"type": "table", "schema": []}, + {"type": "table", "schema": [], "join": True}, {"type": "view", "schema": []}, ]) @@ -445,7 +462,10 @@ def test_join_alias_dot_suggests_cols2(sql): ) def test_on_suggests_aliases(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}] + assert suggestions == [ + {"type": "fk_join", "tables": [(None, "abc", "a"), (None, "bcd", "b")]}, + {"type": "alias", "aliases": ["a", "b"]}, + ] @pytest.mark.parametrize( @@ -457,7 +477,10 @@ def test_on_suggests_aliases(sql): ) def test_on_suggests_tables(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}] + assert suggestions == [ + {"type": "fk_join", "tables": [(None, "abc", None), (None, "bcd", None)]}, + {"type": "alias", "aliases": ["abc", "bcd"]}, + ] @pytest.mark.parametrize( @@ -469,7 +492,10 @@ def test_on_suggests_tables(sql): ) def test_on_suggests_aliases_right_side(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}] + assert suggestions == [ + {"type": "fk_join", "tables": [(None, "abc", "a"), (None, "bcd", "b")]}, + {"type": "alias", "aliases": ["a", "b"]}, + ] @pytest.mark.parametrize( @@ -481,7 +507,10 @@ def test_on_suggests_aliases_right_side(sql): ) def test_on_suggests_tables_right_side(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}] + assert suggestions == [ + {"type": "fk_join", "tables": [(None, "abc", None), (None, "bcd", None)]}, + {"type": "alias", "aliases": ["abc", "bcd"]}, + ] @pytest.mark.parametrize("col_list", ["", "col1, "]) @@ -610,7 +639,7 @@ def test_cross_join(): suggestions = suggest_type(text, text) assert sorted_dicts(suggestions) == sorted_dicts([ {"type": "database"}, - {"type": "table", "schema": []}, + {"type": "table", "schema": [], "join": True}, {"type": "view", "schema": []}, ]) diff --git a/test/pytests/test_completion_refresher.py b/test/pytests/test_completion_refresher.py index e7ed35b2..bc3cedc5 100644 --- a/test/pytests/test_completion_refresher.py +++ b/test/pytests/test_completion_refresher.py @@ -26,6 +26,7 @@ def test_ctor(refresher): "databases", "schemata", "tables", + "foreign_keys", "enum_values", "users", "functions", diff --git a/test/pytests/test_smart_completion_public_schema_only.py b/test/pytests/test_smart_completion_public_schema_only.py index fce8bf9f..404c2147 100644 --- a/test/pytests/test_smart_completion_public_schema_only.py +++ b/test/pytests/test_smart_completion_public_schema_only.py @@ -968,3 +968,170 @@ def test_backticked_no_completion_spaces(completer, complete_event): position = len(text) result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == [] + + +# Foreign key completion tests +@pytest.fixture +def fk_completer(): + """SQLCompleter with tables and a FK relationship. + + Schema: + orders (id, user_id, ordered_date, status) FK: user_id -> users.id + users (id, email, first_name) + tags (id, name) no FK + """ + import mycli.packages.special.main as special + import mycli.sqlcompleter as sqlcompleter + + comp = sqlcompleter.SQLCompleter(smart_completion=True) + + tables = [("orders",), ("users",), ("tags",)] + columns = [ + ("orders", "id"), + ("orders", "user_id"), + ("orders", "ordered_date"), + ("orders", "status"), + ("users", "id"), + ("users", "email"), + ("users", "first_name"), + ("tags", "id"), + ("tags", "name"), + ] + fk_data = [("orders", "user_id", "users", "id")] + + comp.extend_schemata("test") + comp.extend_database_names(["test"]) + comp.set_dbname("test") + comp.extend_relations(tables, kind="tables") + comp.extend_columns(columns, kind="tables") + comp.extend_foreign_keys(fk_data) + comp.extend_special_commands(special.COMMANDS) + + return comp + + +def test_extend_foreign_keys_stores_relation(fk_completer): + relations = fk_completer.dbmetadata["foreign_keys"]["test"]["relations"] + assert ("orders", "user_id", "users", "id") in relations + + +def test_extend_foreign_keys_stores_bidirectional_table_map(fk_completer): + tables_map = fk_completer.dbmetadata["foreign_keys"]["test"]["tables"] + assert "users" in tables_map["orders"] + assert "orders" in tables_map["users"] + + +def test_extend_foreign_keys_unrelated_table_absent_from_map(fk_completer): + tables_map = fk_completer.dbmetadata["foreign_keys"]["test"]["tables"] + assert "tags" not in tables_map + + +def test_fk_join_conditions_with_aliases(fk_completer): + conditions = fk_completer._fk_join_conditions([(None, "orders", "o"), (None, "users", "u")]) + assert conditions == ["o.user_id = u.id"] + + +def test_fk_join_conditions_without_aliases(fk_completer): + conditions = fk_completer._fk_join_conditions([(None, "orders", None), (None, "users", None)]) + assert conditions == ["orders.user_id = users.id"] + + +def test_fk_join_conditions_single_table_yields_nothing(fk_completer): + conditions = fk_completer._fk_join_conditions([(None, "orders", "o")]) + assert conditions == [] + + +def test_fk_join_conditions_unrelated_tables_yields_nothing(fk_completer): + conditions = fk_completer._fk_join_conditions([(None, "orders", "o"), (None, "tags", "t")]) + assert conditions == [] + + +def test_join_suggests_fk_table_before_unrelated(fk_completer, complete_event): + text = "SELECT * FROM orders JOIN " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "users" in result + assert "tags" in result + assert result.index("users") < result.index("tags") + + +def test_join_fk_lookup_is_bidirectional(fk_completer, complete_event): + text = "SELECT * FROM users JOIN " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "orders" in result + assert "tags" in result + assert result.index("orders") < result.index("tags") + + +def test_join_unrelated_table_still_suggests_all_tables(fk_completer, complete_event): + text = "SELECT * FROM tags JOIN " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "orders" in result + assert "users" in result + + +def test_on_suggests_fk_condition_with_aliases(fk_completer, complete_event): + text = "SELECT * FROM orders o JOIN users u ON " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "o.user_id = u.id" in result + + +def test_on_suggests_fk_condition_without_aliases(fk_completer, complete_event): + text = "SELECT * FROM orders JOIN users ON " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "orders.user_id = users.id" in result + + +def test_on_fk_condition_appears_before_aliases(fk_completer, complete_event): + text = "SELECT * FROM orders o JOIN users u ON " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert result.index("o.user_id = u.id") < result.index("o") + + +def test_on_no_fk_condition_for_unrelated_join(fk_completer, complete_event): + text = "SELECT * FROM orders o JOIN tags t ON " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert not any("=" in r for r in result) + assert "o" in result + assert "t" in result + + +def test_on_partial_text_filters_fk_condition(fk_completer, complete_event): + text = "SELECT * FROM orders JOIN users ON ord" + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "orders.user_id = users.id" in result + + +def test_fk_reserved_column_names_are_escaped(): + """FK columns that are reserved words or need quoting must be backtick-escaped.""" + import mycli.sqlcompleter as sqlcompleter + + comp = sqlcompleter.SQLCompleter(smart_completion=True) + comp.extend_schemata("test") + comp.set_dbname("test") + comp.extend_foreign_keys([("orders", "order", "users", "select")]) + + relations = comp.dbmetadata["foreign_keys"]["test"]["relations"] + assert ("orders", "`order`", "users", "`select`") in relations + + conditions = comp._fk_join_conditions([(None, "orders", "o"), (None, "users", "u")]) + assert conditions == ["o.`order` = u.`select`"] + + +def test_fk_conditions_ignore_cross_schema_tables(fk_completer): + """Tables qualified with a foreign schema are excluded from FK condition generation.""" + tables = [("other_db", "orders", "o"), (None, "users", "u")] + conditions = fk_completer._fk_join_conditions(tables) + assert conditions == [] + + +def test_join_priority_ignores_cross_schema_table(fk_completer, complete_event): + """Schema-qualified tables in FROM do not trigger FK priority using current-db metadata.""" + text = "SELECT * FROM other_db.orders JOIN " + result_cross_schema = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + # A table with no FK relationships at all should give the same ordering, + # confirming that no FK priority was applied for the cross-schema table. + text_no_fk = "SELECT * FROM tags JOIN " + result_no_fk = [ + c.text for c in fk_completer.get_completions(Document(text=text_no_fk, cursor_position=len(text_no_fk)), complete_event) + ] + assert result_cross_schema == result_no_fk