From e74e1441fab9631c995d6429bf703adb41ba710a Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Wed, 13 May 2026 18:14:20 +0800 Subject: [PATCH 01/10] refactor(generic): Add more precise type widening --- .../analyzer/lua/for_range_stat.rs | 12 +- .../src/compilation/analyzer/lua/stats.rs | 41 +- .../src/compilation/test/generic_test.rs | 152 ++++++- .../src/compilation/test/unpack_test.rs | 38 ++ .../src/semantic/generic/call_constraint.rs | 10 +- .../instantiate_type/complete_generic_args.rs | 4 +- ..._generic.rs => infer_call_func_generic.rs} | 137 +++--- .../instantiate_type/inference_widening.rs | 405 ++++++++++++++++++ .../instantiate_conditional_generic.rs | 6 +- .../instantiate_special_generic.rs | 48 ++- .../semantic/generic/instantiate_type/mod.rs | 149 ++----- .../src/semantic/generic/mod.rs | 1 - .../src/semantic/generic/test.rs | 33 ++ .../src/semantic/generic/tpl_context.rs | 68 ++- .../tpl_pattern/generic_tpl_pattern.rs | 5 +- .../src/semantic/generic/tpl_pattern/mod.rs | 140 +++--- .../src/semantic/generic/type_substitutor.rs | 398 +++++++++++------ .../src/semantic/infer/infer_call/mod.rs | 48 ++- .../infer/narrow/condition_flow/call_flow.rs | 4 +- .../narrow/condition_flow/correlated_flow.rs | 4 +- .../collect_callable_overloads.rs | 85 ++++ .../src/semantic/overload_resolve/mod.rs | 6 +- .../src/handlers/definition/goto_function.rs | 4 +- .../src/handlers/hover/function/mod.rs | 12 +- .../src/handlers/hover/hover_builder.rs | 2 +- .../src/handlers/test/semantic_token_test.rs | 28 +- 26 files changed, 1390 insertions(+), 450 deletions(-) rename crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/{instantiate_func_generic.rs => infer_call_func_generic.rs} (85%) create mode 100644 crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs create mode 100644 crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs index d6605830b..8a3565802 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs @@ -3,7 +3,7 @@ use emmylua_parser::{LuaAstToken, LuaExpr, LuaForRangeStat}; use crate::{ DbIndex, InferFailReason, LuaDeclId, LuaInferCache, LuaOperatorMetaMethod, LuaType, LuaTypeCache, TplContext, TypeOps, TypeSubstitutor, VariadicType, - compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_doc_function, + compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_type_generic, tpl_pattern_match_args, }; @@ -145,22 +145,18 @@ pub fn infer_for_range_iter_expr_func( return Ok(doc_function.get_variadic_ret()); }; let mut substitutor = TypeSubstitutor::new(); - let mut context = TplContext { - db, - cache, - substitutor: &mut substitutor, - call_expr: None, - }; let params = doc_function .get_params() .iter() .map(|(_, opt_ty)| opt_ty.clone().unwrap_or(LuaType::Any)) .collect::>(); + let mut context = TplContext::new(db, cache, &mut substitutor, None); tpl_pattern_match_args(&mut context, ¶ms, &[status_param])?; + let doc_function_ty = LuaType::DocFunction(doc_function.clone()); let instantiate_func = if let LuaType::DocFunction(f) = - instantiate_doc_function(db, &doc_function, &substitutor) + instantiate_type_generic(db, &doc_function_ty, &substitutor) { f } else { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index 34aaf8186..3bb9875bf 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -47,6 +47,8 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) break; }; + pre_analyze_call_arg_table_fields(analyzer, &expr); + match analyzer.infer_expr(&expr) { Ok(expr_type) => { let expr_type = expr_type.get_result_slot_type(0).unwrap_or(expr_type); @@ -316,6 +318,8 @@ pub fn analyze_assign_stat(analyzer: &mut LuaAnalyzer, assign_stat: LuaAssignSta continue; } + pre_analyze_call_arg_table_fields(analyzer, expr); + let expr_type = match analyzer.infer_expr(expr) { Ok(expr_type) => expr_type.get_result_slot_type(0).unwrap_or(expr_type), Err(InferFailReason::None) => LuaType::Unknown, @@ -536,8 +540,17 @@ pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> } } - let value_expr = field.get_value_expr()?; let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); + if analyzer + .db + .get_type_index() + .get_type_cache(&member_id.into()) + .is_some() + { + return Some(()); + } + let value_expr = field.get_value_expr()?; + let value_type = match analyzer.infer_expr(&value_expr.clone()) { Ok(value_type) => match value_type { LuaType::Def(ref_id) => LuaType::Ref(ref_id), @@ -627,3 +640,29 @@ fn get_delayed_definition_decl_id( } Some(decl_id) } + +fn pre_analyze_call_arg_table_fields(analyzer: &mut LuaAnalyzer, expr: &LuaExpr) { + let LuaExpr::CallExpr(call_expr) = expr else { + return; + }; + let Some(args_list) = call_expr.get_args_list() else { + return; + }; + + for arg in args_list.get_args() { + pre_analyze_table_expr_fields(analyzer, arg); + } +} + +fn pre_analyze_table_expr_fields(analyzer: &mut LuaAnalyzer, expr: LuaExpr) { + let LuaExpr::TableExpr(table_expr) = expr else { + return; + }; + + for field in table_expr.get_fields() { + analyze_table_field(analyzer, field.clone()); + if let Some(value_expr) = field.get_value_expr() { + pre_analyze_table_expr_fields(analyzer, value_expr); + } + } +} diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 3a1b462b9..7aeea9050 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -494,9 +494,12 @@ mod test { A, B, C = f1(1, "2", true) "#, )); - assert_eq!(ws.expr_ty("A"), ws.ty("integer")); - assert_eq!(ws.expr_ty("B"), ws.ty("string")); - assert_eq!(ws.expr_ty("C"), ws.ty("boolean")); + let a_ty = ws.expr_ty("A"); + let b_ty = ws.expr_ty("B"); + let c_ty = ws.expr_ty("C"); + assert_eq!(ws.humanize_type(a_ty), "1"); + assert_eq!(ws.humanize_type(b_ty), "\"2\""); + assert_eq!(ws.humanize_type(c_ty), "true"); } { ws.def( @@ -533,7 +536,8 @@ mod test { G, H = f3(1, "2") "#, )); - assert_eq!(ws.expr_ty("G"), ws.ty("integer")); + let g_ty = ws.expr_ty("G"); + assert_eq!(ws.humanize_type(g_ty), "1"); assert_eq!(ws.expr_ty("H"), ws.ty("any")); } @@ -681,7 +685,7 @@ mod test { "#, )); let result_ty = ws.expr_ty("result"); - assert_eq!(ws.humanize_type(result_ty), "string"); + assert_eq!(ws.humanize_type(result_ty), "\"\""); } #[test] @@ -699,7 +703,7 @@ mod test { ); let result_ty = ws.expr_ty("result"); - assert_eq!(ws.humanize_type(result_ty), "integer"); + assert_eq!(ws.humanize_type(result_ty), "1"); } #[test] @@ -1075,7 +1079,7 @@ mod test { let explicit_result = ws.expr_ty("ExplicitResult"); assert_eq!(ws.humanize_type(explicit_result), "number"); let inferred_result = ws.expr_ty("InferredResult"); - assert_eq!(ws.humanize_type(inferred_result), "integer"); + assert_eq!(ws.humanize_type(inferred_result), "1"); } #[test] @@ -1217,7 +1221,7 @@ mod test { } #[test] - fn test_constant_decay() { + fn test_plain_tpl_literal_key_inference_widens_through_finalize() { let mut ws = VirtualWorkspace::new(); ws.def( r#" @@ -1250,6 +1254,138 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "integer"); } + #[test] + fn test_const_tpl_candidate_preserves_literal_through_plain_return() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias std.ConstTpl unknown + + ---@generic T + ---@param value std.ConstTpl + ---@return T + function keep_const(value) + end + + result = keep_const("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_plain_tpl_top_level_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function id(value) + end + + result = id("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_transparent_alias_top_level_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Id T + + ---@generic T + ---@param value T + ---@return Id + function id(value) + end + + result = id("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_plain_tpl_top_level_return_preserves_primitive_literal_union() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function id(value) + end + + ---@alias Choice "left" | "right" + + ---@type Choice + local choice + + result = id(choice) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("\"left\" | \"right\"")); + } + + #[test] + fn test_primitive_constraint_preserves_literal_candidate() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T: string + ---@param value T + ---@return T + function constrained(value) + end + + result = constrained("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_contextual_widening_keeps_bare_literal_but_widens_nested_literals() { + use crate::{LuaMemberKey, LuaObjectType, WideningContext, widen_type_with_context}; + use smol_str::SmolStr; + + let mut ws = VirtualWorkspace::new(); + let bare = LuaType::StringConst(SmolStr::new("mode").into()); + assert_eq!( + widen_type_with_context(bare.clone(), WideningContext::Root), + bare + ); + + let object = LuaType::Object( + LuaObjectType::new_with_fields( + [( + LuaMemberKey::Name("kind".into()), + LuaType::StringConst(SmolStr::new("mode").into()), + )] + .into_iter() + .collect(), + Vec::new(), + ) + .into(), + ); + let widened = widen_type_with_context(object, WideningContext::Root); + assert_eq!(widened, ws.ty("{ kind: string }")); + } + #[test] fn test_extends_true() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs b/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs index 47ef6dfb4..606957d47 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs @@ -130,6 +130,44 @@ mod test { assert_eq!(ws.humanize_type(b_ty), "number"); } + #[test] + fn test_unpack_alias_call_uses_uninferred_generic_default() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + ws.def( + r#" + ---@generic T = [string, number] + ---@return std.Unpack + function f() + end + + a, b = f() + "#, + ); + + let a_ty = ws.expr_ty("a"); + let b_ty = ws.expr_ty("b"); + assert_eq!(ws.humanize_type(a_ty), "string"); + assert_eq!(ws.humanize_type(b_ty), "number"); + } + + #[test] + fn test_unpack_alias_call_uses_uninferred_generic_constraint() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + ws.def( + r#" + ---@generic T: string[] + ---@return std.Unpack + function f() + end + + a = f() + "#, + ); + + let a_ty = ws.expr_ty("a"); + assert_eq!(ws.humanize_type(a_ty), "string?"); + } + #[test] fn test_issue_484() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs index 1ad1c18ef..e55590b35 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs @@ -33,7 +33,7 @@ pub fn build_call_constraint_context( let mut substitutor = TypeSubstitutor::new(); let generic_tpls = collect_func_tpl_ids(¶ms); if !generic_tpls.is_empty() { - substitutor.add_need_infer_tpls(generic_tpls); + substitutor.prepare_inference_slots(generic_tpls); } // 读取显式传入的泛型实参 @@ -42,7 +42,7 @@ pub fn build_call_constraint_context( DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); for (idx, doc_type) in type_list.get_types().enumerate() { let ty = infer_doc_type(doc_ctx, &doc_type); - substitutor.insert_type(GenericTplId::Func(idx as u32), ty, true); + substitutor.bind_type(GenericTplId::Func(idx as u32), ty); } } @@ -261,16 +261,16 @@ fn record_generic_assignment( match param_type { LuaType::TplRef(tpl_ref) => { if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), true); + substitutor.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); } } LuaType::ConstTplRef(tpl_ref) => { if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), false); + substitutor.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); } } LuaType::StrTplRef(str_tpl_ref) => { - substitutor.insert_type(str_tpl_ref.get_tpl_id(), arg_type.clone(), true); + substitutor.bind_type(str_tpl_ref.get_tpl_id(), arg_type.clone()); } LuaType::Variadic(variadic) => { if let Some(inner) = variadic.get_type(0) { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs index 8ac1a2644..de85aa424 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs @@ -96,7 +96,7 @@ fn complete_type_generic_args_inner( for (idx, generic_param) in generic_params.iter().enumerate() { if let Some(provided_arg) = provided_args.get(idx) { let provided_arg = provided_arg.clone(); - substitutor.insert_type(GenericTplId::Type(idx as u32), provided_arg.clone(), true); + substitutor.bind_type(GenericTplId::Type(idx as u32), provided_arg.clone()); params.push(provided_arg); continue; } @@ -115,7 +115,7 @@ fn complete_type_generic_args_inner( completed_type.ty }; let instantiated = instantiate_type_generic(db, &default_type, &substitutor); - substitutor.insert_type(GenericTplId::Type(idx as u32), instantiated.clone(), true); + substitutor.bind_type(GenericTplId::Type(idx as u32), instantiated.clone()); params.push(instantiated); } else { missing_required_count += 1; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs similarity index 85% rename from crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs rename to crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs index 1c8e39996..b58290125 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs @@ -12,7 +12,6 @@ use crate::{ semantic::{ LuaInferCache, generic::{ - instantiate_type::instantiate_doc_function, tpl_context::TplContext, tpl_pattern::{ multi_param_tpl_pattern_match_multi_return, return_type_pattern_match_target_type, @@ -21,18 +20,19 @@ use crate::{ }, infer::InferFailReason, infer_expr, - overload_resolve::{callable_accepts_args, resolve_signature_by_args}, + overload_resolve::{ + callable_accepts_args, collect_callable_overload_groups, resolve_signature_by_args, + }, }, }; use crate::{ GenericTpl, LuaMemberOwner, LuaSemanticDeclId, LuaTypeOwner, SemanticDeclLevel, TypeVisitTrait, - collect_callable_overload_groups, infer_node_semantic_decl, - tpl_pattern_match_args_skip_unknown, + infer_node_semantic_decl, tpl_pattern_match_args_skip_unknown, }; use super::{TypeSubstitutor, instantiate_type_generic}; -pub fn instantiate_func_generic( +pub fn infer_call_func_generic( db: &DbIndex, cache: &mut LuaInferCache, func: &LuaFunctionType, @@ -53,36 +53,35 @@ pub fn instantiate_func_generic( .get_args() .collect::>(); let mut substitutor = TypeSubstitutor::new(); - let mut context = TplContext { - db, - cache, - substitutor: &mut substitutor, - call_expr: Some(call_expr.clone()), - }; - if !generic_tpls.is_empty() { - context.substitutor.add_need_infer_tpls(generic_tpls); + { + let mut context = TplContext::new(db, cache, &mut substitutor, Some(call_expr.clone())); + if !generic_tpls.is_empty() { + context.substitutor.prepare_inference_slots(generic_tpls); - if let Some(type_list) = call_expr.get_call_generic_type_list() { - // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 - apply_call_generic_type_list(db, file_id, &mut context, &type_list); - } else { - // 如果没有指定泛型, 则需要从调用参数中推断 - infer_generic_types_from_call( - db, - &mut context, - func, - &call_expr, - &mut func_params, - &arg_exprs, - )?; + if let Some(type_list) = call_expr.get_call_generic_type_list() { + // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 + apply_call_generic_type_list(db, file_id, &mut context, &type_list); + } else { + // 如果没有指定泛型, 则需要从调用参数中推断 + infer_generic_types_from_call( + db, + &mut context, + func, + &call_expr, + &mut func_params, + &arg_exprs, + )?; + } } } if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { substitutor.add_self_type(self_type); } + substitutor.finalize_inferred_types(db, func_generic_tpls(func).iter(), func.get_ret()); - if let LuaType::DocFunction(f) = instantiate_doc_function(db, func, &substitutor) { + let func_ty = LuaType::DocFunction(func.clone().into()); + if let LuaType::DocFunction(f) = instantiate_type_generic(db, &func_ty, &substitutor) { Ok(f.deref().clone()) } else { Ok(func.clone()) @@ -100,11 +99,11 @@ fn apply_call_generic_type_list( let typ = infer_doc_type(doc_ctx, &doc_type); context .substitutor - .insert_type(GenericTplId::Func(i as u32), typ, true); + .bind_type(GenericTplId::Func(i as u32), typ); } } -pub fn as_doc_function_type( +fn as_doc_function_type( db: &DbIndex, callable_type: &LuaType, ) -> Result>, InferFailReason> { @@ -197,7 +196,7 @@ fn uses_erased_function_param(callable: &LuaFunctionType, call_arg_types: &[LuaT }) } -pub fn infer_callable_return_from_remaining_args( +fn infer_callable_return_from_remaining_args( context: &mut TplContext, callable_type: &LuaType, arg_exprs: &[LuaExpr], @@ -252,27 +251,36 @@ fn instantiate_callable_from_arg_types( .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) .collect::>(); let mut callable_substitutor = TypeSubstitutor::new(); - callable_substitutor.add_need_infer_tpls(callable_tpls.clone()); - let mut callable_context = TplContext { - db: context.db, - cache: context.cache, - substitutor: &mut callable_substitutor, - call_expr: context.call_expr.clone(), - }; - if tpl_pattern_match_args_skip_unknown( - &mut callable_context, - &callable_param_types, - call_arg_types, - ) - .is_err() + callable_substitutor.prepare_inference_slots(callable_tpls.clone()); { - return None; + let mut callable_context = TplContext::new( + context.db, + context.cache, + &mut callable_substitutor, + context.call_expr.clone(), + ); + if tpl_pattern_match_args_skip_unknown( + &mut callable_context, + &callable_param_types, + call_arg_types, + ) + .is_err() + { + return None; + } } - let instantiated = match instantiate_doc_function(context.db, callable, &callable_substitutor) { - LuaType::DocFunction(func) => func, - _ => callable.clone(), - }; + callable_substitutor.finalize_inferred_types( + context.db, + callable_generic_tpls(callable).iter(), + callable.get_ret(), + ); + let callable_ty = LuaType::DocFunction(callable.clone()); + let instantiated = + match instantiate_type_generic(context.db, &callable_ty, &callable_substitutor) { + LuaType::DocFunction(func) => func, + _ => callable.clone(), + }; let unresolved_return_tpls = { let mut tpl_ids = HashSet::new(); instantiated.get_ret().visit_type(&mut |ty| { @@ -299,9 +307,10 @@ fn instantiate_callable_from_arg_types( } for tpl_id in callback_return_tpls { - callable_substitutor.insert_type(tpl_id, LuaType::Unknown, true); + callable_substitutor.bind_type(tpl_id, LuaType::Unknown); } - match instantiate_doc_function(context.db, callable, &callable_substitutor) { + let callable_ty = LuaType::DocFunction(callable.clone()); + match instantiate_type_generic(context.db, &callable_ty, &callable_substitutor) { LuaType::DocFunction(func) => Some(func), _ => None, } @@ -377,6 +386,27 @@ fn collect_func_tpl_ids(func: &LuaFunctionType) -> (HashSet, bool) (generic_tpls, contain_self) } +fn func_generic_tpls(func: &LuaFunctionType) -> Vec> { + let mut generic_tpls = Vec::new(); + func.visit_nested_types(&mut |ty| match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + if generic_tpl.get_tpl_id().is_func() + && !generic_tpls + .iter() + .any(|it: &Arc| it.get_tpl_id() == generic_tpl.get_tpl_id()) + { + generic_tpls.push(generic_tpl.clone()); + } + } + _ => {} + }); + generic_tpls +} + +fn callable_generic_tpls(callable: &LuaFunctionType) -> Vec> { + func_generic_tpls(callable) +} + fn collect_func_tpl_with_fallback_deps( generic_tpl: &GenericTpl, generic_tpls: &mut HashSet, @@ -488,7 +518,7 @@ fn infer_generic_types_from_call( break; } - if context.substitutor.is_infer_all_tpl() { + if !context.substitutor.has_unresolved_inference_slots() { break; } @@ -511,7 +541,6 @@ fn infer_generic_types_from_call( Err(InferFailReason::FieldNotFound) => LuaType::Nil, // 对于未找到的字段, 我们认为是 nil 以执行后续推断 Err(e) => return Err(e), }; - if let Some(return_pattern) = as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_ret().clone()) { @@ -553,7 +582,7 @@ fn infer_generic_types_from_call( } } - if !context.substitutor.is_infer_all_tpl() { + if context.substitutor.has_unresolved_inference_slots() { for (func_param_type, call_arg_expr) in unresolve_tpls { let closure_type = infer_expr(db, context.cache, call_arg_expr)?; @@ -573,7 +602,7 @@ pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { for (i, generic_param) in generic.iter().enumerate() { let tpl_id = GenericTplId::Type(i as u32); let param = build_self_generic_arg(db, generic_param, &substitutor); - substitutor.insert_type(tpl_id, param.clone(), true); + substitutor.bind_type(tpl_id, param.clone()); params.push(param); } let generic = LuaGenericType::new(id.clone(), params); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs new file mode 100644 index 000000000..ba4cae0a4 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs @@ -0,0 +1,405 @@ +use std::{ops::Deref, sync::Arc}; + +use hashbrown::HashMap; + +use crate::{ + DbIndex, GenericParam, GenericTpl, LuaArrayType, LuaConditionalType, LuaFunctionType, + LuaGenericType, LuaMappedType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaTupleType, + LuaType, LuaUnionType, TypeOps, TypeSubstitutor, VariadicType, instantiate_type_generic, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(in crate::semantic::generic) enum TplCandidateSource { + Plain, + ConstPreserving, + Finalized, +} + +pub(in crate::semantic::generic) fn finalize_inferred_tpl_candidate( + db: &DbIndex, + tpl: &GenericTpl, + raw_candidate: &LuaType, + candidate_source: TplCandidateSource, + top_level: bool, + return_top_level: bool, + substitutor: &TypeSubstitutor, +) -> LuaType { + if candidate_source == TplCandidateSource::ConstPreserving { + return raw_candidate.clone(); + } + + let primitive_constraint = tpl + .get_constraint() + .map(|constraint| { + let constraint = instantiate_type_generic(db, constraint, substitutor); + is_primitive_or_literal_type(&constraint) + }) + .unwrap_or(false); + let candidate = if primitive_constraint || !top_level || return_top_level { + raw_candidate.clone() + } else { + match raw_candidate { + LuaType::FloatConst(_) => LuaType::Number, + LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, + LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, + LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, + _ => raw_candidate.clone(), + } + }; + widen_finalized_candidate_type(db, candidate, WideningContext::Root) +} + +fn is_primitive_or_literal_type(ty: &LuaType) -> bool { + match ty { + LuaType::String + | LuaType::Number + | LuaType::Integer + | LuaType::Boolean + | LuaType::StringConst(_) + | LuaType::DocStringConst(_) + | LuaType::IntegerConst(_) + | LuaType::DocIntegerConst(_) + | LuaType::FloatConst(_) + | LuaType::BooleanConst(_) + | LuaType::DocBooleanConst(_) => true, + LuaType::Tuple(tuple) => tuple.get_types().iter().any(is_primitive_or_literal_type), + LuaType::Union(union) => union.into_vec().iter().any(is_primitive_or_literal_type), + LuaType::MultiLineUnion(union) => union + .get_unions() + .iter() + .any(|(ty, _)| is_primitive_or_literal_type(ty)), + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => is_primitive_or_literal_type(base), + VariadicType::Multi(types) => types.iter().any(is_primitive_or_literal_type), + }, + LuaType::Call(call) => call.get_operands().iter().any(is_primitive_or_literal_type), + _ => false, + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WideningContext { + Root, + UnionMember, + ObjectProperty, + ArrayElement, + TupleElement, + VariadicElement, +} + +fn widen_finalized_candidate_type(db: &DbIndex, ty: LuaType, context: WideningContext) -> LuaType { + match ty { + LuaType::TableConst(table_id) => { + table_const_to_object(db, table_id).unwrap_or(LuaType::Table) + } + LuaType::Object(object) => { + let fields = object + .get_fields() + .iter() + .map(|(key, ty)| { + ( + key.clone(), + widen_finalized_candidate_type( + db, + ty.clone(), + WideningContext::ObjectProperty, + ), + ) + }) + .collect(); + let index_access = object + .get_index_access() + .iter() + .map(|(key, value)| { + ( + widen_type_with_context(key.clone(), WideningContext::ObjectProperty), + widen_finalized_candidate_type( + db, + value.clone(), + WideningContext::ObjectProperty, + ), + ) + }) + .collect(); + LuaType::Object(LuaObjectType::new_with_fields(fields, index_access).into()) + } + LuaType::Array(array) => { + let element_context = match context { + WideningContext::TupleElement => WideningContext::TupleElement, + _ => WideningContext::ArrayElement, + }; + let base = + widen_finalized_candidate_type(db, array.get_base().clone(), element_context); + LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) + } + LuaType::Tuple(tuple) => { + let types = tuple + .get_types() + .iter() + .cloned() + .map(|ty| widen_finalized_candidate_type(db, ty, WideningContext::TupleElement)) + .collect(); + LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) + } + LuaType::Union(union) => { + let member_context = if matches!(context, WideningContext::Root) { + WideningContext::Root + } else { + WideningContext::UnionMember + }; + LuaType::Union( + LuaUnionType::from_vec( + union + .into_vec() + .into_iter() + .map(|ty| widen_finalized_candidate_type(db, ty, member_context)) + .collect(), + ) + .into(), + ) + } + ty => widen_type_with_context(ty, context), + } +} + +pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType { + let widen_literals = !matches!(context, WideningContext::Root); + + match ty { + LuaType::FloatConst(_) if widen_literals => LuaType::Number, + LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) if widen_literals => { + LuaType::Integer + } + LuaType::DocStringConst(_) | LuaType::StringConst(_) if widen_literals => LuaType::String, + LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) if widen_literals => { + LuaType::Boolean + } + LuaType::Array(array) => { + let element_context = match context { + WideningContext::TupleElement => WideningContext::TupleElement, + _ => WideningContext::ArrayElement, + }; + let base = widen_type_with_context(array.get_base().clone(), element_context); + LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) + } + LuaType::Tuple(tuple) => { + let types = tuple + .get_types() + .iter() + .cloned() + .map(|ty| widen_type_with_context(ty, WideningContext::TupleElement)) + .collect(); + LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) + } + LuaType::Object(object) => { + let fields = object + .get_fields() + .iter() + .map(|(key, ty)| { + ( + key.clone(), + widen_type_with_context(ty.clone(), WideningContext::ObjectProperty), + ) + }) + .collect(); + let index_access = object + .get_index_access() + .iter() + .map(|(key, value)| { + ( + widen_type_with_context(key.clone(), WideningContext::ObjectProperty), + widen_type_with_context(value.clone(), WideningContext::ObjectProperty), + ) + }) + .collect(); + LuaType::Object(LuaObjectType::new_with_fields(fields, index_access).into()) + } + LuaType::Union(union) => { + let member_context = if matches!(context, WideningContext::Root) { + WideningContext::Root + } else { + WideningContext::UnionMember + }; + LuaType::Union( + LuaUnionType::from_vec( + union + .into_vec() + .into_iter() + .map(|ty| widen_type_with_context(ty, member_context)) + .collect(), + ) + .into(), + ) + } + LuaType::MultiLineUnion(multi) => LuaType::MultiLineUnion( + crate::LuaMultiLineUnion::new( + multi + .get_unions() + .iter() + .map(|(ty, description)| { + ( + widen_type_with_context(ty.clone(), WideningContext::UnionMember), + description.clone(), + ) + }) + .collect(), + ) + .into(), + ), + LuaType::Intersection(intersection) => LuaType::Intersection( + crate::LuaIntersectionType::new( + intersection + .get_types() + .iter() + .cloned() + .map(|ty| widen_type_with_context(ty, WideningContext::UnionMember)) + .collect(), + ) + .into(), + ), + LuaType::Variadic(variadic) => LuaType::Variadic( + match variadic.deref() { + VariadicType::Base(base) => VariadicType::Base(widen_type_with_context( + base.clone(), + WideningContext::VariadicElement, + )), + VariadicType::Multi(types) => VariadicType::Multi( + types + .iter() + .cloned() + .map(|ty| widen_type_with_context(ty, WideningContext::VariadicElement)) + .collect(), + ), + } + .into(), + ), + LuaType::Generic(generic) => LuaType::Generic( + LuaGenericType::new( + generic.get_base_type_id(), + generic + .get_params() + .iter() + .cloned() + .map(|ty| widen_type_with_context(ty, WideningContext::Root)) + .collect(), + ) + .into(), + ), + LuaType::TableGeneric(params) => LuaType::TableGeneric( + params + .iter() + .cloned() + .map(|ty| widen_type_with_context(ty, WideningContext::Root)) + .collect::>() + .into(), + ), + LuaType::DocFunction(func) => LuaType::DocFunction( + LuaFunctionType::new( + func.get_async_state(), + func.is_colon_define(), + func.is_variadic(), + func.get_params() + .iter() + .map(|(name, ty)| { + ( + name.clone(), + ty.clone() + .map(|ty| widen_type_with_context(ty, WideningContext::Root)), + ) + }) + .collect(), + widen_type_with_context(func.get_ret().clone(), WideningContext::Root), + ) + .into(), + ), + LuaType::TypeGuard(guard) => LuaType::TypeGuard( + widen_type_with_context(guard.deref().clone(), WideningContext::Root).into(), + ), + LuaType::Conditional(conditional) => LuaType::Conditional( + LuaConditionalType::new( + widen_type_with_context( + conditional.get_checked_type().clone(), + WideningContext::Root, + ), + widen_type_with_context( + conditional.get_extends_type().clone(), + WideningContext::Root, + ), + widen_type_with_context(conditional.get_true_type().clone(), WideningContext::Root), + widen_type_with_context( + conditional.get_false_type().clone(), + WideningContext::Root, + ), + conditional.get_infer_params().to_vec(), + conditional.has_new, + ) + .into(), + ), + LuaType::Mapped(mapped) => LuaType::Mapped(Arc::new(LuaMappedType::new( + ( + mapped.param.0, + GenericParam::new( + mapped.param.1.name.clone(), + mapped + .param + .1 + .type_constraint + .clone() + .map(|ty| widen_type_with_context(ty, WideningContext::Root)), + mapped + .param + .1 + .default_type + .clone() + .map(|ty| widen_type_with_context(ty, WideningContext::Root)), + mapped.param.1.attributes.clone(), + ), + ), + widen_type_with_context(mapped.value.clone(), WideningContext::Root), + mapped.is_readonly, + mapped.is_optional, + ))), + ty => ty, + } +} + +fn table_const_to_object( + db: &DbIndex, + table_id: crate::InFiled, +) -> Option { + let owner = LuaMemberOwner::Element(table_id); + let members = db.get_member_index().get_members(&owner)?; + let mut fields = HashMap::new(); + let mut index_access = Vec::new(); + + for member in members { + let value = db + .get_type_index() + .get_type_cache(&member.get_id().into()) + .map(|cache| cache.as_type().clone()) + .unwrap_or(LuaType::Unknown); + let value = widen_finalized_candidate_type(db, value, WideningContext::ObjectProperty); + + match member.get_key() { + LuaMemberKey::Name(_) | LuaMemberKey::Integer(_) => { + fields + .entry(member.get_key().clone()) + .and_modify(|prev| { + *prev = TypeOps::Union.apply(db, prev, &value); + }) + .or_insert(value); + } + LuaMemberKey::ExprType(key) => { + index_access.push(( + widen_type_with_context(key.clone(), WideningContext::ObjectProperty), + value, + )); + } + LuaMemberKey::None => {} + } + } + + Some(LuaType::Object( + LuaObjectType::new_with_fields(fields, index_access).into(), + )) +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index 226a0fb4f..d2aadf06d 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -11,7 +11,7 @@ use crate::{ use super::{ get_default_constructor, instantiate_type_generic, instantiate_type_generic_with_context, }; -use crate::semantic::generic::type_substitutor::GenericInstantiateContext; +use crate::semantic::generic::type_substitutor::{GenericInstantiateContext, TplBinding}; #[derive(Debug, Clone, Copy)] enum InferVariance { @@ -125,7 +125,7 @@ fn instantiate_distributed_conditional( let mut result = LuaType::Never; for member in members { let mut member_substitutor = context.substitutor.clone(); - member_substitutor.replace_type(tpl_id, member, false); + member_substitutor.bind(tpl_id, TplBinding::ReplaceConstType(member)); let member_context = context.with_substitutor(&member_substitutor); let member_result = instantiate_conditional_once(&member_context, conditional); result = TypeOps::Union.apply(context.db, &result, &member_result); @@ -172,7 +172,7 @@ fn instantiate_true_branch( let mut true_substitutor = context.substitutor.clone(); for (tpl_id, ty) in infer_assignments { - true_substitutor.insert_conditional_infer_type(tpl_id, ty); + true_substitutor.bind(tpl_id, TplBinding::ConditionalInferType(ty)); } instantiate_type_generic(context.db, conditional.get_true_type(), &true_substitutor) } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 57fae183d..c44efafa5 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -10,7 +10,10 @@ use crate::{ use hashbrown::HashMap; use std::{ops::Deref, vec}; -use super::{GenericInstantiateContext, TypeSubstitutor, instantiate_type_generic_with_context}; +use super::{ + GenericInstantiateContext, SubstitutorValue, TypeSubstitutor, + instantiate_type_generic_with_context, +}; pub(super) fn instantiate_alias_call( context: &GenericInstantiateContext, @@ -76,7 +79,10 @@ pub(super) fn instantiate_alias_call( instantiate_select_call(&operands[0], &operands[1]) } - LuaAliasCallKind::Unpack => instantiate_unpack_call(context.db, &operands), + LuaAliasCallKind::Unpack => { + let operands = resolve_unpack_operands(context, operand_exprs); + instantiate_unpack_call(context.db, &operands) + } LuaAliasCallKind::RawGet => { if operands.len() != 2 { return LuaType::Unknown; @@ -215,6 +221,44 @@ fn instantiate_select_call(source: &LuaType, index: &LuaType) -> LuaType { } } +fn resolve_unpack_operands( + context: &GenericInstantiateContext, + operand_exprs: &[LuaType], +) -> Vec { + operand_exprs + .iter() + .enumerate() + .map(|(index, operand)| { + if index != 0 { + return instantiate_type_generic_with_context(context, operand); + } + let raw = match operand { + LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => context + .substitutor + .get(tpl_ref.get_tpl_id()) + .and_then(|value| match value { + SubstitutorValue::None => None, + SubstitutorValue::Type(ty) => Some(ty.raw().clone()), + SubstitutorValue::MultiTypes { raw_types, .. } => Some(LuaType::Variadic( + VariadicType::Multi(raw_types.clone()).into(), + )), + SubstitutorValue::Params(params) => Some( + params + .first() + .unwrap_or(&(String::new(), None)) + .1 + .clone() + .unwrap_or(LuaType::Unknown), + ), + SubstitutorValue::MultiBase(base) => Some(base.clone()), + }), + _ => None, + }; + raw.unwrap_or_else(|| instantiate_type_generic_with_context(context, operand)) + }) + .collect() +} + fn instantiate_unpack_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { if operands.is_empty() { return LuaType::Unknown; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index ae9d07e77..571244289 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -1,10 +1,11 @@ mod complete_generic_args; +mod infer_call_func_generic; +mod inference_widening; mod instantiate_conditional_generic; -mod instantiate_func_generic; mod instantiate_special_generic; use hashbrown::{HashMap, HashSet}; -use std::{ops::Deref, sync::Arc}; +use std::ops::Deref; use crate::{ DbIndex, GenericTpl, GenericTplId, LuaArrayType, LuaMappedType, LuaMemberKey, @@ -14,7 +15,6 @@ use crate::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaType, LuaUnionType, VariadicType, }, - semantic::infer::InferFailReason, }; use super::type_substitutor::{ @@ -23,94 +23,23 @@ use super::type_substitutor::{ pub use complete_generic_args::{ GenericArgumentCompletion, complete_type_generic_args, complete_type_generic_args_in_type, }; -pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate_func_generic}; +pub use infer_call_func_generic::{build_self_type, infer_call_func_generic, infer_self_type}; +pub(in crate::semantic::generic) use inference_widening::{ + TplCandidateSource, finalize_inferred_tpl_candidate, +}; +pub use inference_widening::{WideningContext, widen_type_with_context}; pub use instantiate_special_generic::get_keyof_members; -pub(crate) fn collect_callable_overload_groups( - db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, -) -> Result<(), InferFailReason> { - let mut visiting_aliases = HashSet::new(); - collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) -} - -fn collect_callable_overload_groups_inner( - db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, - visiting_aliases: &mut HashSet, -) -> Result<(), InferFailReason> { - match callable_type { - LuaType::Ref(type_id) | LuaType::Def(type_id) => { - let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { - return Ok(()); - }; - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); - } - - let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(type_id); - result?; - } - LuaType::Generic(generic) => { - let type_id = generic.get_base_type_id(); - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); - } - let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); - let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { - visiting_aliases.remove(&type_id); - return Ok(()); - }; - - let result = if let Some(origin_type) = - type_decl.get_alias_origin(db, Some(&substitutor)) - { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(&type_id); - result?; - } - LuaType::Union(union) => { - for member in union.into_vec() { - collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; - } - } - LuaType::Intersection(intersection) => { - for member in intersection.get_types() { - collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; - } - } - LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), - LuaType::Signature(sig_id) => { - let Some(signature) = db.get_signature_index().get(sig_id) else { - return Ok(()); - }; - let mut overloads = signature.overloads.to_vec(); - overloads.push(signature.to_doc_func_type()); - groups.push(overloads); - } - _ => {} - } - - Ok(()) -} - pub fn instantiate_type_generic( db: &DbIndex, ty: &LuaType, substitutor: &TypeSubstitutor, ) -> LuaType { let context = GenericInstantiateContext::new(db, substitutor); - instantiate_type_generic_with_context(&context, ty) + match ty { + LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context(&context, doc_func), + _ => instantiate_type_generic_with_context(&context, ty), + } } pub(super) fn instantiate_type_generic_with_context( @@ -199,7 +128,7 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) match value { SubstitutorValue::None => new_types .push(instantiate_uninferred_tpl_fallback(tpl, context)), - SubstitutorValue::MultiTypes(types) => { + SubstitutorValue::MultiTypes { types, .. } => { for typ in types { new_types.push(typ.clone()); } @@ -209,7 +138,7 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) new_types.push(ty.clone().unwrap_or(LuaType::Unknown)); } } - SubstitutorValue::Type(ty) => new_types.push(ty.default().clone()), + SubstitutorValue::Type(ty) => new_types.push(ty.resolved().clone()), SubstitutorValue::MultiBase(base) => new_types.push(base.clone()), } } else { @@ -229,15 +158,6 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) LuaType::Tuple(LuaTupleType::new(new_types, tuple.status).into()) } -pub fn instantiate_doc_function( - db: &DbIndex, - doc_func: &LuaFunctionType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_doc_function_with_context(&context, doc_func) -} - fn instantiate_doc_function_with_context( context: &GenericInstantiateContext, doc_func: &LuaFunctionType, @@ -266,11 +186,11 @@ fn instantiate_doc_function_with_context( new_params.push((origin_param.0.clone(), Some(ty))); } SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + let resolved_type = ty.resolved().clone(); // 如果参数是 `...: T...` if origin_param.0 == "..." { // 类型是 tuple, 那么我们将展开 tuple - if let LuaType::Tuple(tuple) = resolved_type { + if let LuaType::Tuple(tuple) = &resolved_type { let base_index = new_params.len(); for (i, typ) in tuple.get_types().iter().enumerate() { let param_name = format!("var{}", base_index + i); @@ -288,7 +208,7 @@ fn instantiate_doc_function_with_context( new_params.push(( origin_param.0.clone(), Some(LuaType::Variadic( - VariadicType::Base(resolved_type.clone()).into(), + VariadicType::Base(resolved_type).into(), )), )); } @@ -297,7 +217,7 @@ fn instantiate_doc_function_with_context( new_params.push(param.clone()); } } - SubstitutorValue::MultiTypes(types) => { + SubstitutorValue::MultiTypes { types, .. } => { for (i, typ) in types.iter().enumerate() { let param_name = format!("var{}", i); new_params.push((param_name, Some(typ.clone()))); @@ -411,15 +331,6 @@ fn instantiate_intersection( ) } -pub fn instantiate_generic( - db: &DbIndex, - generic: &LuaGenericType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_generic_with_context(&context, generic) -} - fn instantiate_generic_with_context( context: &GenericInstantiateContext, generic: &LuaGenericType, @@ -481,8 +392,10 @@ fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> SubstitutorValue::None => { return instantiate_uninferred_tpl_fallback(tpl, context); } - SubstitutorValue::Type(ty) => return ty.default().clone(), - SubstitutorValue::MultiTypes(types) => { + SubstitutorValue::Type(ty) => { + return ty.resolved().clone(); + } + SubstitutorValue::MultiTypes { types, .. } => { return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); } SubstitutorValue::Params(params) => { @@ -506,8 +419,10 @@ fn instantiate_const_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateConte SubstitutorValue::None => { return instantiate_uninferred_tpl_fallback(tpl, context); } - SubstitutorValue::Type(ty) => return ty.raw().clone(), - SubstitutorValue::MultiTypes(types) => { + SubstitutorValue::Type(ty) => { + return ty.resolved().clone(); + } + SubstitutorValue::MultiTypes { types, .. } => { return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); } SubstitutorValue::Params(params) => { @@ -570,18 +485,16 @@ fn instantiate_variadic_type( }; } SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + let resolved_type = ty.resolved().clone(); if matches!( resolved_type, LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never ) { - return resolved_type.clone(); + return resolved_type; } - return LuaType::Variadic( - VariadicType::Base(resolved_type.clone()).into(), - ); + return LuaType::Variadic(VariadicType::Base(resolved_type).into()); } - SubstitutorValue::MultiTypes(types) => { + SubstitutorValue::MultiTypes { types, .. } => { return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); } SubstitutorValue::Params(params) => { @@ -704,7 +617,7 @@ fn instantiate_mapped_value( replacement: &LuaType, ) -> LuaType { let mut local_substitutor = context.substitutor.clone(); - local_substitutor.insert_type(tpl_id, replacement.clone(), true); + local_substitutor.bind_type(tpl_id, replacement.clone()); let local_context = context.with_substitutor(&local_substitutor); let mut result = instantiate_type_generic_with_context(&local_context, &mapped.value); // 根据 readonly 和 optional 属性进行处理 diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index 90e34baa3..582e1e9ee 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -11,7 +11,6 @@ pub use call_constraint::{ }; use emmylua_parser::LuaAstNode; use emmylua_parser::LuaExpr; -pub(crate) use instantiate_type::collect_callable_overload_groups; pub use instantiate_type::*; use rowan::NodeOrToken; pub use tpl_context::TplContext; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/test.rs b/crates/emmylua_code_analysis/src/semantic/generic/test.rs index 21dee2f3f..be0188d9a 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/test.rs @@ -298,4 +298,37 @@ result = { "# )); } + + #[test] + fn test_123() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param x T + ---@return T + function f(x) + return x + end + + A = f("hello") + B = f({value = "hello"}) + C = B.value + "#, + ); + + let a_ty = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(a_ty), "\"hello\""); + + let b_ty = ws.expr_ty("B"); + let b_desc = ws.humanize_type_detailed(b_ty); + assert!( + b_desc.contains("value: string"), + "unexpected type: {}", + b_desc + ); + + let c_ty = ws.expr_ty("C"); + assert_eq!(ws.humanize_type(c_ty), "string"); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs index 02200a4a6..bab917412 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs @@ -1,6 +1,10 @@ use emmylua_parser::LuaCallExpr; -use crate::{DbIndex, LuaInferCache, TypeSubstitutor}; +use super::instantiate_type::TplCandidateSource; +use crate::{ + DbIndex, GenericTplId, LuaInferCache, LuaType, TypeSubstitutor, + semantic::generic::type_substitutor::TplBinding, +}; #[derive(Debug)] pub struct TplContext<'a> { @@ -8,4 +12,66 @@ pub struct TplContext<'a> { pub cache: &'a mut LuaInferCache, pub substitutor: &'a mut TypeSubstitutor, pub call_expr: Option, + inference_top_level: bool, +} + +impl<'a> TplContext<'a> { + pub fn new( + db: &'a DbIndex, + cache: &'a mut LuaInferCache, + substitutor: &'a mut TypeSubstitutor, + call_expr: Option, + ) -> Self { + Self { + db, + cache, + substitutor, + call_expr, + inference_top_level: true, + } + } + + pub fn with_inference_top_level( + &mut self, + top_level: bool, + f: impl FnOnce(&mut Self) -> R, + ) -> R { + let previous = self.inference_top_level; + self.inference_top_level = previous && top_level; + let result = f(self); + self.inference_top_level = previous; + result + } + + pub(in crate::semantic::generic) fn insert_type( + &mut self, + tpl_id: GenericTplId, + replace_type: LuaType, + source: TplCandidateSource, + ) { + self.substitutor.bind( + tpl_id, + TplBinding::InferredType { + ty: replace_type, + source, + top_level: self.inference_top_level, + }, + ); + } + + pub(in crate::semantic::generic) fn insert_multi_types( + &mut self, + tpl_id: GenericTplId, + types: Vec, + source: TplCandidateSource, + ) { + self.substitutor.bind( + tpl_id, + TplBinding::InferredMultiTypes { + types, + source, + top_level: self.inference_top_level, + }, + ); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs index aa556ca88..a6c2751fe 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs @@ -1,6 +1,6 @@ use crate::{ InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaType, LuaTypeNode, TplContext, - TypeSubstitutor, instantiate_generic, instantiate_type_generic, + TypeSubstitutor, instantiate_type_generic, semantic::generic::tpl_pattern::{ TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, }, @@ -125,7 +125,8 @@ fn generic_tpl_pattern_match_inner( _ => { // 对于 @alias 类型, 我们能拿到的 target 实际上很有可能是实例化后的类型, 因此我们需要实例化后再进行匹配 let substitutor = TypeSubstitutor::new(); - let typ = instantiate_generic(context.db, source_generic, &substitutor); + let generic_ty = LuaType::Generic(source_generic.clone().into()); + let typ = instantiate_type_generic(context.db, &generic_ty, &substitutor); if LuaType::from(source_generic.clone()) != typ { tpl_pattern_match(context, &typ, target)?; } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 77900f2d4..657e04897 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -23,7 +23,10 @@ use crate::{ }, }; -use super::type_substitutor::TypeSubstitutor; +use super::{ + instantiate_type::TplCandidateSource::{ConstPreserving, Plain}, + type_substitutor::{TplBinding, TypeSubstitutor}, +}; use std::collections::HashMap; type TplPatternMatchResult = Result<(), InferFailReason>; @@ -159,16 +162,12 @@ pub fn tpl_pattern_match( match pattern { LuaType::TplRef(tpl) => { if tpl.get_tpl_id().is_func() { - context - .substitutor - .insert_type(tpl.get_tpl_id(), target.clone(), true); + context.insert_type(tpl.get_tpl_id(), target.clone(), Plain); } } LuaType::ConstTplRef(tpl) => { if tpl.get_tpl_id().is_func() { - context - .substitutor - .insert_type(tpl.get_tpl_id(), target, false); + context.insert_type(tpl.get_tpl_id(), target, ConstPreserving); } } LuaType::StrTplRef(str_tpl) => { @@ -176,33 +175,45 @@ pub fn tpl_pattern_match( let prefix = str_tpl.get_prefix(); let suffix = str_tpl.get_suffix(); let type_name = SmolStr::new(format!("{}{}{}", prefix, s, suffix)); - context.substitutor.insert_type( + context.insert_type( str_tpl.get_tpl_id(), get_str_tpl_infer_type(&type_name), - true, + Plain, ); } } LuaType::Array(array_type) => { - array_tpl_pattern_match(context, array_type.get_base(), &target)?; + context.with_inference_top_level(false, |context| { + array_tpl_pattern_match(context, array_type.get_base(), &target) + })?; } LuaType::TableGeneric(table_generic_params) => { - table_generic_tpl_pattern_match(context, table_generic_params, &target)?; + context.with_inference_top_level(false, |context| { + table_generic_tpl_pattern_match(context, table_generic_params, &target) + })?; } LuaType::Generic(generic) => { - generic_tpl_pattern_match(context, generic, &target)?; + context.with_inference_top_level(false, |context| { + generic_tpl_pattern_match(context, generic, &target) + })?; } LuaType::Union(union) => { union_tpl_pattern_match(context, union, &target)?; } LuaType::DocFunction(doc_func) => { - func_tpl_pattern_match(context, doc_func, &target)?; + context.with_inference_top_level(false, |context| { + func_tpl_pattern_match(context, doc_func, &target) + })?; } LuaType::Tuple(tuple) => { - tuple_tpl_pattern_match(context, tuple, &target)?; + context.with_inference_top_level(false, |context| { + tuple_tpl_pattern_match(context, tuple, &target) + })?; } LuaType::Object(obj) => { - object_tpl_pattern_match(context, obj, &target)?; + context.with_inference_top_level(false, |context| { + object_tpl_pattern_match(context, obj, &target) + })?; } _ => {} } @@ -210,16 +221,6 @@ pub fn tpl_pattern_match( Ok(()) } -pub fn constant_decay(typ: LuaType) -> LuaType { - match &typ { - LuaType::FloatConst(_) => LuaType::Number, - LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, - LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, - LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, - _ => typ, - } -} - fn object_tpl_pattern_match( context: &mut TplContext, origin_obj: &LuaObjectType, @@ -646,7 +647,7 @@ fn param_type_list_pattern_match_type_list( if i >= targets.len() { if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() { let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + context.insert_type(tpl_id, LuaType::Nil, Plain); } break; } @@ -658,7 +659,7 @@ fn param_type_list_pattern_match_type_list( SubstitutorValue::Type(_) => { continue; } - SubstitutorValue::MultiTypes(types) => { + SubstitutorValue::MultiTypes { types, .. } => { if types.len() > 1 { target_offset += types.len() - 1; } @@ -717,9 +718,7 @@ pub(crate) fn return_type_pattern_match_target_type( VariadicType::Base(source_base) => { if let LuaType::TplRef(type_ref) = source_base { let tpl_id = type_ref.get_tpl_id(); - context - .substitutor - .insert_type(tpl_id, target_base.clone(), true); + context.insert_type(tpl_id, target_base.clone(), Plain); } } VariadicType::Multi(source_multi) => { @@ -730,22 +729,14 @@ pub(crate) fn return_type_pattern_match_target_type( && let LuaType::TplRef(type_ref) = base { let tpl_id = type_ref.get_tpl_id(); - context.substitutor.insert_type( - tpl_id, - target_base.clone(), - true, - ); + context.insert_type(tpl_id, target_base.clone(), Plain); } break; } LuaType::TplRef(tpl_ref) => { let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.insert_type( - tpl_id, - target_base.clone(), - true, - ); + context.insert_type(tpl_id, target_base.clone(), Plain); } _ => {} } @@ -784,12 +775,14 @@ fn func_varargs_tpl_pattern_match( VariadicType::Base(base) => { if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); - substitutor.insert_params( + substitutor.bind( tpl_id, - target_rest_params - .iter() - .map(|(n, t)| (n.clone(), t.clone())) - .collect(), + TplBinding::VariadicParams( + target_rest_params + .iter() + .map(|(n, t)| (n.clone(), t.clone())) + .collect(), + ), ); } } @@ -810,7 +803,7 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + context.insert_type(tpl_id, LuaType::Nil, Plain); } 1 => { // If the single argument is itself a multi-return (e.g. a function call @@ -820,42 +813,28 @@ pub fn variadic_tpl_pattern_match( LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Multi(types) => match types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + context.insert_type(tpl_id, LuaType::Nil, Plain); } 1 => { - context.substitutor.insert_type( - tpl_id, - types[0].clone(), - true, - ); + context.insert_type(tpl_id, types[0].clone(), Plain); } _ => { - context.substitutor.insert_multi_types( - tpl_id, - types - .iter() - .map(|t| constant_decay(t.clone())) - .collect(), - ); + context.insert_multi_types(tpl_id, types.to_vec(), Plain); } }, VariadicType::Base(base) => { - context.substitutor.insert_multi_base(tpl_id, base.clone()); + context + .substitutor + .bind(tpl_id, TplBinding::VariadicBase(base.clone())); } }, arg => { - context.substitutor.insert_type(tpl_id, arg.clone(), true); + context.insert_type(tpl_id, arg.clone(), Plain); } } } _ => { - context.substitutor.insert_multi_types( - tpl_id, - target_rest_types - .iter() - .map(|t| constant_decay(t.clone())) - .collect(), - ); + context.insert_multi_types(tpl_id, target_rest_types.to_vec(), Plain); } } } @@ -863,19 +842,17 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, false); + context.insert_type(tpl_id, LuaType::Nil, ConstPreserving); } 1 => { - context.substitutor.insert_type( - tpl_id, - target_rest_types[0].clone(), - false, - ); + context.insert_type(tpl_id, target_rest_types[0].clone(), ConstPreserving); } _ => { - context - .substitutor - .insert_multi_types(tpl_id, target_rest_types.to_vec()); + context.insert_multi_types( + tpl_id, + target_rest_types.to_vec(), + ConstPreserving, + ); } } } @@ -895,7 +872,7 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.get(i) { Some(t) => { - context.substitutor.insert_type(tpl_id, t.clone(), true); + context.insert_type(tpl_id, t.clone(), Plain); } None => { break; @@ -946,9 +923,10 @@ fn tuple_tpl_pattern_match( VariadicType::Base(base) => { if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); - context - .substitutor - .insert_multi_base(tpl_id, target_array_base.get_base().clone()); + context.substitutor.bind( + tpl_id, + TplBinding::VariadicBase(target_array_base.get_base().clone()), + ); } } VariadicType::Multi(_) => {} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index b045bda1d..5f6f03be7 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -1,7 +1,8 @@ use hashbrown::{HashMap, HashSet}; -use super::tpl_pattern::constant_decay; -use crate::{DbIndex, GenericTplId, LuaType, LuaTypeDeclId}; +use super::instantiate_type::{TplCandidateSource, finalize_inferred_tpl_candidate}; +use crate::{DbIndex, GenericTpl, GenericTplId, LuaType, LuaTypeDeclId}; +use std::sync::Arc; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(super) enum UninferredTplPolicy { @@ -11,6 +12,24 @@ pub(super) enum UninferredTplPolicy { PreserveTplRef, } +pub(in crate::semantic::generic) enum TplBinding { + FinalizedType(LuaType), + InferredType { + ty: LuaType, + source: TplCandidateSource, + top_level: bool, + }, + ReplaceConstType(LuaType), + ConditionalInferType(LuaType), + VariadicParams(Vec<(String, Option)>), + InferredMultiTypes { + types: Vec, + source: TplCandidateSource, + top_level: bool, + }, + VariadicBase(LuaType), +} + #[derive(Debug)] pub struct GenericInstantiateContext<'a> { pub db: &'a DbIndex, @@ -78,7 +97,11 @@ impl TypeSubstitutor { for (i, ty) in type_array.into_iter().enumerate() { tpl_replace_map.insert( GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new(ty, true)), + SubstitutorValue::Type(SubstitutorTypeValue::new( + ty, + TplCandidateSource::Finalized, + true, + )), ); } Self { @@ -93,7 +116,11 @@ impl TypeSubstitutor { for (i, ty) in type_array.into_iter().enumerate() { tpl_replace_map.insert( GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new(ty, true)), + SubstitutorValue::Type(SubstitutorTypeValue::new( + ty, + TplCandidateSource::Finalized, + true, + )), ); } Self { @@ -103,7 +130,7 @@ impl TypeSubstitutor { } } - pub fn add_need_infer_tpls(&mut self, tpl_ids: HashSet) { + pub fn prepare_inference_slots(&mut self, tpl_ids: HashSet) { for tpl_id in tpl_ids { // conditional infer id 只属于条件类型内部匹配, 不参与普通调用/类型泛型推导. if tpl_id.is_conditional_infer() { @@ -116,62 +143,95 @@ impl TypeSubstitutor { } } - pub fn is_infer_all_tpl(&self) -> bool { - for value in self.tpl_replace_map.values() { - if let SubstitutorValue::None = value { - return false; - } - } - true - } - - pub fn insert_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType, decay: bool) { - // 普通替换入口不能写入 conditional infer, 避免条件类型局部绑定泄露到外层. - if tpl_id.is_conditional_infer() { - return; - } - - self.insert_type_value(tpl_id, SubstitutorTypeValue::new(replace_type, decay)); - } - - pub(super) fn replace_type( - &mut self, - tpl_id: GenericTplId, - replace_type: LuaType, - decay: bool, - ) { - if tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, decay)), - ); + pub fn has_unresolved_inference_slots(&self) -> bool { + self.tpl_replace_map + .values() + .any(|value| matches!(value, SubstitutorValue::None)) } - pub fn insert_conditional_infer_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType) { - // 只有 conditional true 分支提交 infer 结果时允许写入 scoped conditional infer id. - if !tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, false)), - ); + pub fn bind_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType) { + self.bind(tpl_id, TplBinding::FinalizedType(replace_type)); } - fn insert_type_value(&mut self, tpl_id: GenericTplId, value: SubstitutorTypeValue) { - if !self.can_insert_type(tpl_id) { - return; + pub(in crate::semantic::generic) fn bind(&mut self, tpl_id: GenericTplId, binding: TplBinding) { + match binding { + TplBinding::ConditionalInferType(replace_type) => { + // 只有 conditional true 分支提交 infer 结果时允许写入 scoped conditional infer id. + if !tpl_id.is_conditional_infer() { + return; + } + + self.tpl_replace_map.insert( + tpl_id, + SubstitutorValue::Type(SubstitutorTypeValue::new( + replace_type, + TplCandidateSource::ConstPreserving, + true, + )), + ); + } + TplBinding::ReplaceConstType(replace_type) => { + if tpl_id.is_conditional_infer() { + return; + } + + self.tpl_replace_map.insert( + tpl_id, + SubstitutorValue::Type(SubstitutorTypeValue::new( + replace_type, + TplCandidateSource::ConstPreserving, + true, + )), + ); + } + binding => { + // 普通替换入口不能写入 conditional infer, 避免条件类型局部绑定泄露到外层. + if tpl_id.is_conditional_infer() || !self.can_bind(tpl_id) { + return; + } + + let value = match binding { + TplBinding::FinalizedType(replace_type) => { + SubstitutorValue::Type(SubstitutorTypeValue::new( + replace_type, + TplCandidateSource::Finalized, + true, + )) + } + TplBinding::InferredType { + ty, + source, + top_level, + } => SubstitutorValue::Type(SubstitutorTypeValue::new(ty, source, top_level)), + TplBinding::VariadicParams(params) => { + let params = params + .into_iter() + .map(|(name, ty)| (name, ty.map(into_ref_type))) + .collect(); + SubstitutorValue::Params(params) + } + TplBinding::InferredMultiTypes { + types, + source, + top_level, + } => SubstitutorValue::MultiTypes { + raw_types: types.clone(), + types, + source, + top_level, + }, + TplBinding::VariadicBase(type_base) => SubstitutorValue::MultiBase(type_base), + TplBinding::ReplaceConstType(_) | TplBinding::ConditionalInferType(_) => { + unreachable!("handled before regular binding") + } + }; + + self.tpl_replace_map.insert(tpl_id, value); + } } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::Type(value)); } - fn can_insert_type(&self, tpl_id: GenericTplId) -> bool { + fn can_bind(&self, tpl_id: GenericTplId) -> bool { if let Some(value) = self.tpl_replace_map.get(&tpl_id) { return value.is_none(); } @@ -179,51 +239,7 @@ impl TypeSubstitutor { true } - pub fn insert_params(&mut self, tpl_id: GenericTplId, params: Vec<(String, Option)>) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - let params = params - .into_iter() - .map(|(name, ty)| (name, ty.map(into_ref_type))) - .collect(); - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::Params(params)); - } - - pub fn insert_multi_types(&mut self, tpl_id: GenericTplId, types: Vec) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiTypes(types)); - } - - pub fn insert_multi_base(&mut self, tpl_id: GenericTplId, type_base: LuaType) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiBase(type_base)); - } - - pub fn get(&self, tpl_id: GenericTplId) -> Option<&SubstitutorValue> { + pub(super) fn get(&self, tpl_id: GenericTplId) -> Option<&SubstitutorValue> { self.tpl_replace_map.get(&tpl_id) } @@ -234,6 +250,57 @@ impl TypeSubstitutor { } } + pub(super) fn finalize_inferred_types<'a>( + &mut self, + db: &DbIndex, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + ) { + for tpl in generic_tpls { + let tpl_id = tpl.get_tpl_id(); + let return_top_level = is_tpl_at_top_level(db, return_type, tpl_id); + let substitutor = self.clone(); + let Some(value) = self.tpl_replace_map.get_mut(&tpl_id) else { + continue; + }; + + match value { + SubstitutorValue::Type(ty) => { + ty.finalize(db, tpl.as_ref(), return_top_level, &substitutor) + } + SubstitutorValue::MultiTypes { + raw_types, + types, + source, + top_level, + } => { + if *source == TplCandidateSource::Finalized { + continue; + } + let finalized = types + .iter() + .map(|ty| { + finalize_inferred_tpl_candidate( + db, + tpl.as_ref(), + ty, + *source, + *top_level, + return_top_level, + &substitutor, + ) + }) + .collect(); + *raw_types = types.clone(); + *types = finalized; + *source = TplCandidateSource::Finalized; + *top_level = true; + } + _ => {} + } + } + } + pub fn check_recursion(&self, type_id: &LuaTypeDeclId) -> bool { if let Some(alias_type_id) = &self.alias_type_id && alias_type_id == type_id @@ -256,49 +323,67 @@ impl TypeSubstitutor { #[derive(Debug, Clone)] pub struct SubstitutorTypeValue { raw: LuaType, - decayed: DecayedType, -} - -#[derive(Debug, Clone)] -enum DecayedType { - Same, - Cached(LuaType), + finalized: Option, + source: TplCandidateSource, + top_level: bool, } impl SubstitutorTypeValue { - pub fn new(raw: LuaType, decay: bool) -> Self { + fn new(raw: LuaType, source: TplCandidateSource, top_level: bool) -> Self { let raw = into_ref_type(raw); - let decayed = if decay { - let decayed = into_ref_type(constant_decay(raw.clone())); - if decayed == raw { - DecayedType::Same - } else { - DecayedType::Cached(decayed) - } - } else { - DecayedType::Same - }; - Self { raw, decayed } + let finalized = (source == TplCandidateSource::Finalized).then(|| raw.clone()); + Self { + raw, + finalized, + source, + top_level, + } } pub fn raw(&self) -> &LuaType { &self.raw } - pub fn default(&self) -> &LuaType { - match &self.decayed { - DecayedType::Same => &self.raw, - DecayedType::Cached(decayed) => decayed, + pub(super) fn resolved(&self) -> &LuaType { + self.finalized.as_ref().unwrap_or(&self.raw) + } + + fn finalize( + &mut self, + db: &DbIndex, + tpl: &GenericTpl, + return_top_level: bool, + substitutor: &TypeSubstitutor, + ) { + if self.finalized.is_some() { + return; } + + self.finalized = Some(finalize_inferred_tpl_candidate( + db, + tpl, + &self.raw, + self.source, + self.top_level, + return_top_level, + substitutor, + )); + self.source = TplCandidateSource::Finalized; + self.top_level = true; } } #[derive(Debug, Clone)] -pub enum SubstitutorValue { +pub(super) enum SubstitutorValue { None, Type(SubstitutorTypeValue), Params(Vec<(String, Option)>), - MultiTypes(Vec), + MultiTypes { + raw_types: Vec, + types: Vec, + source: TplCandidateSource, + top_level: bool, + }, MultiBase(LuaType), } @@ -308,6 +393,71 @@ impl SubstitutorValue { } } +fn is_tpl_at_top_level(db: &DbIndex, ty: &LuaType, tpl_id: GenericTplId) -> bool { + is_tpl_at_top_level_with_guard(db, ty, tpl_id, &mut HashSet::new()) +} + +fn is_tpl_at_top_level_with_guard( + db: &DbIndex, + ty: &LuaType, + tpl_id: GenericTplId, + visited_aliases: &mut HashSet, +) -> bool { + match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + LuaType::Generic(generic) => { + let type_decl_id = generic.get_base_type_id_ref(); + let Some(alias_param) = + get_transparent_alias_param_index(db, type_decl_id, visited_aliases) + else { + return false; + }; + + generic.get_params().get(alias_param).is_some_and(|param| { + is_tpl_at_top_level_with_guard(db, param, tpl_id, visited_aliases) + }) + } + _ => false, + } +} + +fn get_transparent_alias_param_index( + db: &DbIndex, + type_decl_id: &LuaTypeDeclId, + visited_aliases: &mut HashSet, +) -> Option { + if !visited_aliases.insert(type_decl_id.clone()) { + return None; + } + + let type_decl = db.get_type_index().get_type_decl(type_decl_id)?; + if !type_decl.is_alias() { + return None; + }; + let origin = type_decl.get_alias_ref()?; + + match origin { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + LuaType::Generic(generic) => { + get_transparent_alias_param_index(db, generic.get_base_type_id_ref(), visited_aliases) + .and_then(|alias_param| generic.get_params().get(alias_param)) + .and_then(|param| match param { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + _ => None, + }) + } + _ => None, + } +} + fn into_ref_type(ty: LuaType) -> LuaType { match ty { LuaType::Def(type_decl_id) => LuaType::Ref(type_decl_id), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index cd1360cf2..9231e09e6 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -8,7 +8,7 @@ use super::{ super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature}, InferFailReason, InferResult, }; -use crate::semantic::overload_resolve::callable_accepts_args; +use crate::semantic::overload_resolve::{callable_accepts_args, collect_callable_overload_groups}; use crate::{ AsyncState, CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId, @@ -17,14 +17,11 @@ use crate::{ use crate::{ InferGuardRef, semantic::{ - generic::{ - TypeSubstitutor, collect_callable_overload_groups, get_tpl_ref_extend_type, - instantiate_doc_function, - }, + generic::{TypeSubstitutor, get_tpl_ref_extend_type}, infer::narrow::get_type_at_call_expr_inline_cast, }, }; -use crate::{build_self_type, infer_self_type, instantiate_func_generic, semantic::infer_expr}; +use crate::{build_self_type, infer_call_func_generic, infer_self_type, semantic::infer_expr}; use infer_require::infer_require_call; use infer_setmetatable::infer_setmetatable_call; @@ -136,7 +133,7 @@ pub fn infer_call_expr_func( let result = if let Ok(func_ty) = result { let func_ty = match func_ty.get_ret() { LuaType::Call(_) => { - match instantiate_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { + match infer_call_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { Ok(func_ty) => Arc::new(func_ty), Err(_) => func_ty, } @@ -223,7 +220,7 @@ fn infer_doc_function( call_expr: LuaCallExpr, ) -> InferCallFuncResult { if func.contain_tpl() { - let result = instantiate_func_generic(db, cache, func, call_expr)?; + let result = infer_call_func_generic(db, cache, func, call_expr)?; return Ok(Arc::new(result)); } @@ -271,9 +268,10 @@ fn filter_callable_overloads_by_call_args( let has_tpls = !callable_tpls.is_empty(); let mut substitutor = TypeSubstitutor::new(); - substitutor.add_need_infer_tpls(callable_tpls); + substitutor.prepare_inference_slots(callable_tpls); let match_func = if has_tpls { - match instantiate_doc_function(db, func, &substitutor) { + let func_ty = LuaType::DocFunction(func.clone()); + match instantiate_type_generic(db, &func_ty, &substitutor) { LuaType::DocFunction(doc_func) => doc_func, _ => func.clone(), } @@ -362,13 +360,15 @@ fn infer_type_doc_function( }; if has_generic_tpl { - let result = instantiate_func_generic(db, cache, &f, call_expr.clone())?; + let result = infer_call_func_generic(db, cache, &f, call_expr.clone())?; overloads.push(Arc::new(result)); } else if f.contain_self() { let mut substitutor = TypeSubstitutor::new(); let self_type = build_self_type(db, call_expr_type); substitutor.add_self_type(self_type); - if let LuaType::DocFunction(f) = instantiate_doc_function(db, &f, &substitutor) + let func_ty = LuaType::DocFunction(f.clone()); + if let LuaType::DocFunction(f) = + instantiate_type_generic(db, &func_ty, &substitutor) { overloads.push(f); } @@ -903,6 +903,30 @@ mod tests { assert_eq!(ws.expr_ty("payload"), ws.ty("string")); } + #[test] + fn test_top_level_generic_literal_keeps_function_param_and_return_consistent() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + local function id(value) end + + id("hello") + "#, + ); + let call_expr = ws.get_node::(file_id); + let semantic_model = ws.analysis.compilation.get_semantic_model(file_id).unwrap(); + let func = semantic_model + .infer_call_expr_func(call_expr, None) + .unwrap(); + + let param_ty = func.get_params()[0].1.clone().unwrap(); + assert_eq!(ws.humanize_type(param_ty), "\"hello\""); + assert_eq!(ws.humanize_type(func.get_ret().clone()), "\"hello\""); + } + #[test] fn test_union_call_ignores_unresolved_alias_member() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index e915915b1..7b9d10f3a 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -16,7 +16,7 @@ use crate::{ var_ref_id::get_var_expr_var_ref_id, }, }, - semantic::instantiate_func_generic, + semantic::infer_call_func_generic, }; pub fn get_type_at_call_expr( @@ -226,7 +226,7 @@ fn get_type_guard_call_info( let mut return_type = func_type.get_ret().clone(); if return_type.contain_tpl() { let Ok(inst_func) = cache.with_no_flow(|cache| { - instantiate_func_generic(db, cache, func_type.as_ref(), call_expr) + infer_call_func_generic(db, cache, func_type.as_ref(), call_expr) }) else { return Ok(None); }; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index 78c464852..61728c1e1 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -7,7 +7,7 @@ use crate::{ LuaSignature, LuaType, TypeOps, semantic::{ infer::{InferResult, VarRefId, narrow::narrow_down_type, try_infer_expr_no_flow}, - instantiate_func_generic, + infer_call_func_generic, }, }; @@ -579,7 +579,7 @@ fn instantiate_return_rows( return_type.clone(), ); match cache - .with_no_flow(|cache| instantiate_func_generic(db, cache, &func, call_expr.clone())) + .with_no_flow(|cache| infer_call_func_generic(db, cache, &func, call_expr.clone())) { Ok(instantiated) => instantiated.get_ret().clone(), Err(_) => return_type, diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs new file mode 100644 index 000000000..fd5f568af --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs @@ -0,0 +1,85 @@ +use std::sync::Arc; + +use hashbrown::HashSet; + +use crate::db_index::{DbIndex, LuaFunctionType, LuaType, LuaTypeDeclId}; + +use super::super::{generic::TypeSubstitutor, infer::InferFailReason}; + +pub(crate) fn collect_callable_overload_groups( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, +) -> Result<(), InferFailReason> { + let mut visiting_aliases = HashSet::new(); + collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) +} + +fn collect_callable_overload_groups_inner( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, + visiting_aliases: &mut HashSet, +) -> Result<(), InferFailReason> { + match callable_type { + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { + return Ok(()); + }; + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + + let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(type_id); + result?; + } + LuaType::Generic(generic) => { + let type_id = generic.get_base_type_id(); + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { + visiting_aliases.remove(&type_id); + return Ok(()); + }; + + let result = if let Some(origin_type) = + type_decl.get_alias_origin(db, Some(&substitutor)) + { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(&type_id); + result?; + } + LuaType::Union(union) => { + for member in union.into_vec() { + collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; + } + } + LuaType::Intersection(intersection) => { + for member in intersection.get_types() { + collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; + } + } + LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), + LuaType::Signature(sig_id) => { + let Some(signature) = db.get_signature_index().get(sig_id) else { + return Ok(()); + }; + let mut overloads = signature.overloads.to_vec(); + overloads.push(signature.to_doc_func_type()); + groups.push(overloads); + } + _ => {} + } + + Ok(()) +} diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index a6447a91c..203d97870 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -1,3 +1,4 @@ +mod collect_callable_overloads; mod resolve_signature_by_args; use std::sync::Arc; @@ -8,10 +9,11 @@ use crate::db_index::{DbIndex, LuaFunctionType, LuaType}; use super::{ LuaInferCache, - generic::instantiate_func_generic, + generic::infer_call_func_generic, infer::{InferCallFuncResult, InferFailReason, infer_expr_list_types, try_infer_expr_no_flow}, }; +pub(crate) use collect_callable_overloads::collect_callable_overload_groups; pub(crate) use resolve_signature_by_args::{callable_accepts_args, resolve_signature_by_args}; pub fn resolve_signature( @@ -78,7 +80,7 @@ fn resolve_signature_by_generic( ) -> InferCallFuncResult { let mut instantiate_funcs = Vec::new(); for func in overloads { - let instantiate_func = instantiate_func_generic(db, cache, &func, call_expr.clone())?; + let instantiate_func = infer_call_func_generic(db, cache, &func, call_expr.clone())?; instantiate_funcs.push(Arc::new(instantiate_func)); } resolve_signature_by_args( diff --git a/crates/emmylua_ls/src/handlers/definition/goto_function.rs b/crates/emmylua_ls/src/handlers/definition/goto_function.rs index f60d9b395..971c74aae 100644 --- a/crates/emmylua_ls/src/handlers/definition/goto_function.rs +++ b/crates/emmylua_ls/src/handlers/definition/goto_function.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ LuaCompilation, LuaDeclId, LuaFunctionType, LuaSemanticDeclId, LuaSignature, LuaSignatureId, - LuaType, SemanticDeclLevel, SemanticModel, instantiate_func_generic, + LuaType, SemanticDeclLevel, SemanticModel, infer_call_func_generic, }; use emmylua_parser::{ LuaAstNode, LuaCallExpr, LuaExpr, LuaLiteralToken, LuaSyntaxToken, LuaTokenKind, @@ -291,7 +291,7 @@ pub fn compare_function_types( call_expr: &LuaCallExpr, ) -> Option { if func.contain_tpl() { - let instantiated_func = instantiate_func_generic( + let instantiated_func = infer_call_func_generic( semantic_model.get_db(), &mut semantic_model.get_cache().borrow_mut(), func, diff --git a/crates/emmylua_ls/src/handlers/hover/function/mod.rs b/crates/emmylua_ls/src/handlers/hover/function/mod.rs index 2402e078e..c8f3e4818 100644 --- a/crates/emmylua_ls/src/handlers/hover/function/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/function/mod.rs @@ -3,8 +3,8 @@ use std::{collections::HashSet, sync::Arc, vec}; use emmylua_code_analysis::{ AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, - TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, instantiate_doc_function, - instantiate_func_generic, try_extract_signature_id_from_field, + TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, infer_call_func_generic, + instantiate_type_generic, try_extract_signature_id_from_field, }; use crate::handlers::hover::{ @@ -104,7 +104,7 @@ fn build_function_call_hover( signature.get_type_params(), signature.get_return_type(), ); - let instantiated_signature = instantiate_func_generic( + let instantiated_signature = infer_call_func_generic( db, &mut builder.semantic_model.get_cache().borrow_mut(), &base_function, @@ -486,7 +486,7 @@ fn instantiate_call_return_overloads( row_return_type, ); let instantiated_row = - instantiate_func_generic(db, &mut cache, &row_function, call_expr.clone()) + infer_call_func_generic(db, &mut cache, &row_function, call_expr.clone()) .ok() .map(|func| match func.get_ret() { LuaType::Variadic(variadic) => match variadic.as_ref() { @@ -702,8 +702,8 @@ fn hover_instantiate_function_type( return None; } match typ { - LuaType::DocFunction(f) => { - if let LuaType::DocFunction(f) = instantiate_doc_function(db, f, substitutor) { + LuaType::DocFunction(_) => { + if let LuaType::DocFunction(f) = instantiate_type_generic(db, typ, substitutor) { Some(f) } else { None diff --git a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs index d4cd08773..61b81de1f 100644 --- a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs +++ b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs @@ -336,7 +336,7 @@ pub fn substitutor_form_expr( let mut substitutor = TypeSubstitutor::new(); if let LuaType::Generic(generic) = prefix_type { for (i, param) in generic.get_params().iter().enumerate() { - substitutor.insert_type(GenericTplId::Type(i as u32), param.clone(), true); + substitutor.bind_type(GenericTplId::Type(i as u32), param.clone()); } return Some(substitutor); } else { diff --git a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs index 8132c2f3e..c7f6c9066 100644 --- a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs +++ b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs @@ -28,19 +28,6 @@ mod tests { result } - fn make_issue_1028_repeated_prefix_guard_chain_content() -> String { - let mut content = String::from("V_cfad19afc42b = V_cfad19afc42b or {}\n"); - for i in 0..600 { - let table_key = 3_121_212; - let field_key = 1_111_112 + i; - content.push_str(&format!( - "if V_cfad19afc42b[{table_key}] and V_cfad19afc42b[{table_key}][{field_key}] then\n V_cfad19afc42b[{table_key}][{field_key}][\"__STR_{i}__\"] = \"__STR_{}__\"\nend\n\n", - i + 1, - )); - } - content - } - #[gtest] fn test_1() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -148,7 +135,22 @@ m.foo() Ok(()) } + #[cfg(feature = "full-test")] + fn make_issue_1028_repeated_prefix_guard_chain_content() -> String { + let mut content = String::from("V_cfad19afc42b = V_cfad19afc42b or {}\n"); + for i in 0..600 { + let table_key = 3_121_212; + let field_key = 1_111_112 + i; + content.push_str(&format!( + "if V_cfad19afc42b[{table_key}] and V_cfad19afc42b[{table_key}][{field_key}] then\n V_cfad19afc42b[{table_key}][{field_key}][\"__STR_{i}__\"] = \"__STR_{}__\"\nend\n\n", + i + 1, + )); + } + content + } + #[gtest] + #[cfg(feature = "full-test")] fn test_issue_1028_i18n_semantic_tokens_repeated_prefix_guard_chain() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); let content = make_issue_1028_repeated_prefix_guard_chain_content(); From 2b6798e7ad577b72a03beac5aed8ce94d229e8c1 Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Sat, 16 May 2026 04:30:41 +0800 Subject: [PATCH 02/10] update generic --- .../resources/std/builtin.lua | 7 + .../analyzer/doc/file_generic_index.rs | 49 ++- .../compilation/analyzer/doc/type_def_tags.rs | 5 +- .../analyzer/doc/type_generic_header.rs | 17 +- .../src/compilation/test/generic_builtin.rs | 237 ++++++++++ .../src/compilation/test/generic_test.rs | 172 +++++++- .../src/compilation/test/mod.rs | 1 + .../src/db_index/type/generic_param.rs | 9 +- .../generic/generic_constraint_mismatch.rs | 159 +++++-- .../instantiate_conditional_generic.rs | 412 +++++++++--------- .../instantiate_mapped_type.rs | 248 +++++++++++ .../instantiate_special_generic.rs | 64 ++- .../semantic/generic/instantiate_type/mod.rs | 379 ++++++++-------- .../tpl_pattern/generic_tpl_pattern.rs | 1 + .../src/semantic/generic/tpl_pattern/mod.rs | 8 +- .../src/semantic/generic/type_substitutor.rs | 271 ++++++++---- .../src/semantic/infer/infer_call/mod.rs | 13 +- .../src/semantic/infer/infer_index/mod.rs | 40 +- .../src/semantic/member/find_members.rs | 11 +- .../src/semantic/member/infer_raw_member.rs | 8 +- .../semantic/type_check/complex_type/mod.rs | 7 +- .../src/semantic/type_check/func_type.rs | 7 +- .../src/semantic/type_check/generic_type.rs | 10 +- .../src/semantic/type_check/simple_type.rs | 7 +- 24 files changed, 1516 insertions(+), 626 deletions(-) create mode 100644 crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs create mode 100644 crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs diff --git a/crates/emmylua_code_analysis/resources/std/builtin.lua b/crates/emmylua_code_analysis/resources/std/builtin.lua index da7557002..af7194301 100644 --- a/crates/emmylua_code_analysis/resources/std/builtin.lua +++ b/crates/emmylua_code_analysis/resources/std/builtin.lua @@ -167,6 +167,13 @@ --- Extract from T those types that are assignable to U --- @alias Extract T extends U and T or never +--- +--- From T, pick a set of properties whose keys are in the union K +--- @alias Pick {[P in K]: T[P]; } + +--- +--- Construct a type with the properties of T except for those in type K. +--- @alias Omit Pick> --- attribute diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs index e3ad351d2..fa35dd075 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs @@ -9,12 +9,24 @@ use crate::{GenericParam, GenericTpl, GenericTplId, LuaType}; pub trait GenericIndex: std::fmt::Debug { fn add_generic_scope(&mut self, ranges: Vec, is_func: bool) -> GenericScopeId; - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam); - - fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option; + + fn append_generic_params( + &mut self, + scope_id: GenericScopeId, + params: Vec, + ) -> Vec { + let mut appended = Vec::new(); for param in params { - self.append_generic_param(scope_id, param); + if let Some(tpl_id) = self.append_generic_param(scope_id, param.clone()) { + appended.push(param.with_tpl_id(Some(tpl_id))); + } } + appended } fn find_generic( @@ -63,16 +75,29 @@ impl GenericIndex for FileGenericIndex { scope_id } - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam) { + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option { if let Some(scope) = self.scopes.get_mut(scope_id.id) { - scope.insert_param(param); + return Some(scope.insert_param(param)); } + None } - fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { + fn append_generic_params( + &mut self, + scope_id: GenericScopeId, + params: Vec, + ) -> Vec { + let mut appended = Vec::new(); for param in params { - self.append_generic_param(scope_id, param); + if let Some(tpl_id) = self.append_generic_param(scope_id, param.clone()) { + appended.push(param.with_tpl_id(Some(tpl_id))); + } } + appended } /// Find generic parameter by position and name. @@ -131,10 +156,12 @@ impl FileGenericScope { self.next_tpl_id.is_func() } - fn insert_param(&mut self, param: GenericParam) { - let tpl_id = self.next_tpl_id; - self.next_tpl_id = self.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); + fn insert_param(&mut self, param: GenericParam) -> GenericTplId { + let tpl_id = param.tpl_id.unwrap_or(self.next_tpl_id); + let next_idx = self.next_tpl_id.get_idx().max(tpl_id.get_idx() + 1) as u32; + self.next_tpl_id = self.next_tpl_id.with_idx(next_idx); self.params.insert(param.name.to_string(), (tpl_id, param)); + tpl_id } fn contains(&self, position: TextSize) -> bool { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs index d63ed8290..0fc9ad2b9 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs @@ -142,9 +142,10 @@ pub fn analyze_alias(analyzer: &mut DocAnalyzer, tag: LuaDocTagAlias) -> Option< alias_decl.get_id() }; + let type_node = tag.get_type()?; if tag.get_generic_decl_list().is_some() { let generic_params = get_type_generic_params(analyzer, &alias_decl_id); - let range = analyzer.comment.get_range(); + let range = type_node.get_range(); let scope_id = analyzer .type_context .generic_index @@ -155,7 +156,7 @@ pub fn analyze_alias(analyzer: &mut DocAnalyzer, tag: LuaDocTagAlias) -> Option< .append_generic_params(scope_id, generic_params); } - let mut origin_type = infer_type(&mut analyzer.type_context, tag.get_type()?); + let mut origin_type = infer_type(&mut analyzer.type_context, type_node); if alias_origin_reaches(analyzer.get_db(), &origin_type, &alias_decl_id) { origin_type = LuaType::Any; } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs index ed8601db4..dcd26943f 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs @@ -68,6 +68,7 @@ fn normalize_generic_params(db: &DbIndex, params: &[GenericParam]) -> Vec Option { + let scope = self.scopes.get_mut(scope_id.id)?; let tpl_id = scope.next_tpl_id; scope.next_tpl_id = scope.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); scope.params.push((tpl_id, param)); + Some(tpl_id) } fn find_generic( diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs new file mode 100644 index 000000000..798c5e667 --- /dev/null +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs @@ -0,0 +1,237 @@ +#[cfg(test)] +mod test { + use crate::{DiagnosticCode, VirtualWorkspace}; + + #[test] + fn test_builtin_pick_preserves_selected_properties() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinPickUser + ---@field name string + ---@field age number + ---@field email string + ---@field nickname? string + + ---@type Pick + local picked + PickedName = picked.name + PickedAge = picked.age + PickedNickname = picked.nickname + + ---@type Pick + local pickedAll + PickedAllEmail = pickedAll.email + + ---@type Pick<{id: integer, enabled: boolean, label?: string}, "id" | "label"> + local pickedLiteral + PickedLiteralId = pickedLiteral.id + PickedLiteralLabel = pickedLiteral.label + "#, + ); + + assert_eq!(ws.expr_ty("PickedName"), ws.ty("string")); + assert_eq!(ws.expr_ty("PickedAge"), ws.ty("number")); + assert_eq!(ws.expr_ty("PickedNickname"), ws.ty("string?")); + assert_eq!(ws.expr_ty("PickedAllEmail"), ws.ty("string")); + assert_eq!(ws.expr_ty("PickedLiteralId"), ws.ty("integer")); + assert_eq!(ws.expr_ty("PickedLiteralLabel"), ws.ty("string?")); + } + + #[test] + fn test_builtin_pick_matches_ts6_key_constraint() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinPickConstraintUser + ---@field name string + ---@field age number + "#, + ); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Pick + local picked + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Pick + local picked + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Pick + local picked + local name = picked.name + "# + )); + } + + #[test] + fn test_builtin_pick_empty_keyof_domain_converges() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinEmptyPickClass + + ---@type Pick<{}, keyof {}> + local pickedEmptyObject + PickedEmptyObjectMissing = pickedEmptyObject.missing + + ---@type Pick + local pickedEmptyClass + PickedEmptyClassMissing = pickedEmptyClass.missing + "#, + ); + + assert_eq!(ws.expr_ty("PickedEmptyObjectMissing"), ws.ty("nil")); + assert_eq!(ws.expr_ty("PickedEmptyClassMissing"), ws.ty("nil")); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Pick<{}, keyof {}> + local picked + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Pick<{}, keyof {}> + local picked + local missing = picked.missing + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@class BuiltinEmptyPickDiagnosticClass + + ---@type Pick + local picked + local missing = picked.missing + "# + )); + } + + #[test] + fn test_builtin_omit_removes_properties() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinOmitUser + ---@field name string + ---@field age number + ---@field email string + ---@field nickname? string + + ---@type Omit + local omitted + OmittedName = omitted.name + OmittedAge = omitted.age + OmittedNickname = omitted.nickname + OmittedEmail = omitted.email + + ---@type Pick> + local pickedWithoutEmail + PickedWithoutEmailEmail = pickedWithoutEmail.email + + ---@type Omit<{id: integer, enabled: boolean, label?: string}, "enabled"> + local omittedLiteral + OmittedLiteralId = omittedLiteral.id + OmittedLiteralLabel = omittedLiteral.label + "#, + ); + + assert_eq!(ws.expr_ty("OmittedName"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmittedAge"), ws.ty("number")); + assert_eq!(ws.expr_ty("OmittedNickname"), ws.ty("string?")); + assert_eq!(ws.expr_ty("PickedWithoutEmailEmail"), ws.ty("nil")); + assert_eq!(ws.expr_ty("OmittedEmail"), ws.ty("nil")); + assert_eq!(ws.expr_ty("OmittedLiteralId"), ws.ty("integer")); + assert_eq!(ws.expr_ty("OmittedLiteralLabel"), ws.ty("string?")); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Omit + local omitted + local email = omitted.email + "# + )); + } + + #[test] + fn test_builtin_omit_matches_ts6_keyof_any_behavior() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinOmitKeyUser + ---@field name string + ---@field age number + ---@field email string + + ---@type Omit + local omitMissing + OmitMissingName = omitMissing.name + OmitMissingEmail = omitMissing.email + + ---@type Omit + local omitNever + OmitNeverName = omitNever.name + OmitNeverEmail = omitNever.email + + ---@type Omit + local omitAll + OmitAllName = omitAll.name + "#, + ); + + assert_eq!(ws.expr_ty("OmitMissingName"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitMissingEmail"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitNeverName"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitNeverEmail"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitAllName"), ws.ty("nil")); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Omit + local omitted + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Omit + local omitted + local name = omitted.name + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Omit + local omitted + local name = omitted.name + "# + )); + } +} diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 7aeea9050..5092c6e3f 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -3,8 +3,8 @@ mod test { use emmylua_parser::LuaClosureExpr; use crate::{ - DiagnosticCode, LuaSignatureId, LuaType, LuaTypeDeclId, VirtualWorkspace, - complete_type_generic_args, + DiagnosticCode, GenericTplId, LuaSignatureId, LuaType, LuaTypeDeclId, TypeSubstitutor, + VirtualWorkspace, complete_type_generic_args, instantiate_type_generic, }; #[test] @@ -412,6 +412,172 @@ mod test { )); } + #[test] + fn test_keyof_alias_residual_resolves_after_forwarding() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Keys keyof T + ---@alias ForwardKeys Keys + + ---@param key "a" | "b" + function accept(key) end + "#, + ); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type ForwardKeys<{ a: string, b: number }> + local key + accept(key) + "# + )); + } + + #[test] + fn test_mapped_alias_residual_resolves_after_forwarding() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Copy { [K in keyof T]: T[K]; } + ---@alias ForwardCopy Copy + + ---@type ForwardCopy<{ a: string, b: number }> + local copy + + A = copy.a + B = copy.b + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + assert_eq!(ws.expr_ty("B"), ws.ty("number")); + } + + #[test] + fn test_mapped_unresolved_key_domain_preserves_residual() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Copy { [K in keyof T]: T[K]; } + + ---@generic T + ---@param value Copy + ---@return Copy + function keep(value) end + + ---@type Copy<{ a: string }> + local concrete + + A = keep(concrete).a + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + } + + #[test] + fn test_alias_argument_binding_ignores_shadowing_function_generic() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Box fun(x: T): T + + ---@type Box + local f + + Result = f(1) + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "1"); + assert!(ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type Box + local f + + f(1) + "# + )); + } + + #[test] + fn test_alias_argument_binding_ignores_shadowing_mapped_key() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Shadow { [T in keyof { a: string }]: T; } + + ---@type Shadow + local value + + A = value.a + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty(r#""a""#)); + } + + #[test] + fn test_conditional_alias_residual_resolves_after_forwarding() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Extract T extends U and T or never + ---@alias KeepA Extract + ---@alias Forward KeepA + "#, + ); + + let generic_ty = ws.ty(r#"Forward<"a" | "b">"#); + let instantiated = + instantiate_type_generic(ws.get_db_mut(), &generic_ty, &TypeSubstitutor::new()); + assert_eq!(instantiated, ws.ty(r#""a""#)); + } + + #[test] + fn test_nested_mapped_conditional_alias_residual_resolves() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Wrapper + ---@alias UnwrapFields { [K in keyof T]: T[K] extends Wrapper and U or T[K]; } + ---@alias Forward UnwrapFields + + ---@type Forward<{ a: Wrapper, b: number }> + local value + + A = value.a + B = value.b + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + assert_eq!(ws.expr_ty("B"), ws.ty("number")); + } + + #[test] + fn test_recursive_alias_instantiation_budget_falls_back_safely() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Loop Loop + ---@alias Forward Loop + + ---@type Forward + local value + + Value = value + "#, + ); + + let value_ty = ws.expr_ty("Value"); + assert_eq!(ws.humanize_type(value_ty), "Forward"); + } + #[test] fn test_issue_787() { let mut ws = VirtualWorkspace::new(); @@ -768,6 +934,7 @@ mod test { .expect("Box generic params"); assert_eq!(box_params.len(), 1); assert_eq!(box_params[0].name.as_str(), "T"); + assert_eq!(box_params[0].tpl_id, Some(GenericTplId::Type(0))); let box_default = box_params[0] .default_type .clone() @@ -783,6 +950,7 @@ mod test { .expect("Optional generic params"); assert_eq!(optional_params.len(), 1); assert_eq!(optional_params[0].name.as_str(), "T"); + assert_eq!(optional_params[0].tpl_id, Some(GenericTplId::Type(0))); let optional_default = optional_params[0] .default_type .clone() diff --git a/crates/emmylua_code_analysis/src/compilation/test/mod.rs b/crates/emmylua_code_analysis/src/compilation/test/mod.rs index a6c8629fd..3b6807bde 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/mod.rs @@ -10,6 +10,7 @@ mod decl_test; mod diagnostic_disable_test; mod flow; mod for_range_var_infer_test; +mod generic_builtin; mod generic_infer_test; mod generic_test; mod infer_str_tpl_test; diff --git a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs index 1a66b2031..08978d8be 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs @@ -1,10 +1,11 @@ use smol_str::SmolStr; -use crate::{LuaAttributeUse, LuaType}; +use crate::{GenericTplId, LuaAttributeUse, LuaType}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct GenericParam { pub name: SmolStr, + pub tpl_id: Option, pub type_constraint: Option, pub default_type: Option, pub attributes: Option>, @@ -19,9 +20,15 @@ impl GenericParam { ) -> Self { Self { name, + tpl_id: None, type_constraint, default_type, attributes, } } + + pub fn with_tpl_id(mut self, tpl_id: Option) -> Self { + self.tpl_id = tpl_id; + self + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs index cc5594eff..c535f984a 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs @@ -1,6 +1,6 @@ use emmylua_parser::{ LuaAst, LuaAstNode, LuaCallExpr, LuaClosureExpr, LuaComment, LuaDocGenericDeclList, - LuaDocTagAlias, LuaDocTagClass, LuaDocTagGeneric, LuaDocTagType, LuaDocType, + LuaDocGenericType, LuaDocTagAlias, LuaDocTagClass, LuaDocTagGeneric, LuaDocTagType, LuaDocType, }; use rowan::TextRange; use smol_str::SmolStr; @@ -13,7 +13,7 @@ use crate::semantic::{ use crate::{ DiagnosticCode, DocTypeInferContext, GenericTplId, LuaArrayType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaSignatureId, LuaStringTplType, LuaTupleType, LuaType, - LuaUnionType, RenderLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, + LuaTypeNode, LuaUnionType, RenderLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, TypeSubstitutor, VariadicType, humanize_type, infer_doc_type, instantiate_type_generic, }; @@ -617,55 +617,142 @@ fn check_doc_tag_type( let type_list = doc_tag_type.get_type_list(); let doc_ctx = DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); for doc_type in type_list { - let explicit_args = explicit_generic_args(&doc_type); - if explicit_args.is_empty() { - continue; - } + check_doc_type_generic_constraints(context, semantic_model, doc_ctx, &doc_type); + } + Some(()) +} - let type_ref = infer_doc_type(doc_ctx, &doc_type); - let generic_type = match type_ref { - LuaType::Generic(generic_type) => generic_type, - _ => continue, - }; +fn check_doc_type_generic_constraints( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + doc_ctx: DocTypeInferContext<'_>, + doc_type: &LuaDocType, +) -> Option<()> { + let LuaDocType::Generic(generic_doc_type) = doc_type else { + return Some(()); + }; + + let explicit_args = explicit_generic_args(generic_doc_type); + if explicit_args.is_empty() { + return Some(()); + } - let generic_params = semantic_model + let name = generic_doc_type.get_name_type()?.get_name_text()?; + let type_decl = semantic_model.get_db().get_type_index().find_type_decl( + semantic_model.get_file_id(), + &name, + semantic_model .get_db() - .get_type_index() - .get_generic_params(&generic_type.get_base_type_id())?; - for (i, param_type) in generic_type - .get_params() - .iter() - .take(explicit_args.len()) - .enumerate() - { - let extend_type = generic_params.get(i)?.type_constraint.clone()?; - let result = semantic_model.type_check_detail(&extend_type, param_type); - if result.is_err() { - add_type_check_diagnostic( - context, - semantic_model, - explicit_args.get(i)?.get_range(), - &extend_type, - param_type, - result, - ); + .resolve_workspace_id(semantic_model.get_file_id()), + )?; + let type_id = type_decl.get_id(); + let generic_params = semantic_model + .get_db() + .get_type_index() + .get_generic_params(&type_id)?; + + let instantiate_arg = explicit_arg_instantiation_flags(&generic_params, explicit_args.len()); + let empty_substitutor = TypeSubstitutor::new(); + let param_types = explicit_args + .iter() + .enumerate() + .map(|(idx, doc_type)| { + let ty = infer_doc_type(doc_ctx, doc_type); + if instantiate_arg.get(idx).copied().unwrap_or(false) { + instantiate_type_generic(semantic_model.get_db(), &ty, &empty_substitutor) + } else { + ty } + }) + .collect::>(); + + let substitutor = + TypeSubstitutor::from_alias(semantic_model.get_db(), param_types.clone(), type_id); + + for (i, param_type) in param_types.iter().enumerate() { + let Some(explicit_arg) = explicit_args.get(i) else { + continue; + }; + let Some(extend_type) = generic_params + .get(i) + .and_then(|param| param.type_constraint.clone()) + else { + continue; + }; + + let mut extend_type = + instantiate_type_generic(semantic_model.get_db(), &extend_type, &substitutor); + extend_type = normalize_keyof_any_constraint(extend_type); + let result = semantic_model.type_check_detail(&extend_type, param_type); + if result.is_err() { + add_type_check_diagnostic( + context, + semantic_model, + explicit_arg.get_range(), + &extend_type, + param_type, + result, + ); } } + Some(()) } -fn explicit_generic_args(doc_type: &LuaDocType) -> Vec { - let LuaDocType::Generic(generic_doc_type) = doc_type else { - return Vec::new(); - }; - +fn explicit_generic_args(generic_doc_type: &LuaDocGenericType) -> Vec { generic_doc_type .get_generic_types() .map(|type_list| type_list.get_types().collect()) .unwrap_or_default() } +fn explicit_arg_instantiation_flags( + generic_params: &[crate::GenericParam], + explicit_arg_count: usize, +) -> Vec { + let mut flags = vec![false; explicit_arg_count]; + for (constraint_index, param) in generic_params.iter().enumerate().take(explicit_arg_count) { + let Some(constraint) = param.type_constraint.as_ref() else { + continue; + }; + + flags[constraint_index] = true; + for (arg_index, referenced_param) in + generic_params.iter().enumerate().take(explicit_arg_count) + { + let tpl_id = referenced_param + .tpl_id + .unwrap_or(GenericTplId::Type(arg_index as u32)); + if type_contains_tpl_ref(constraint, tpl_id) { + flags[arg_index] = true; + } + } + } + + flags +} + +fn type_contains_tpl_ref(ty: &LuaType, tpl_id: GenericTplId) -> bool { + ty.any_type(|ty| match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + LuaType::StrTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + _ => false, + }) +} + +fn normalize_keyof_any_constraint(ty: LuaType) -> LuaType { + match ty { + LuaType::Call(alias_call) + if alias_call.get_call_kind() == crate::LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 + && alias_call.get_operands()[0].is_any() => + { + LuaType::from_vec(vec![LuaType::String, LuaType::Integer, LuaType::Number]) + } + _ => ty, + } +} + #[allow(clippy::too_many_arguments)] fn check_param( context: &mut DiagnosticContext, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index d2aadf06d..d12d8b3fb 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -1,23 +1,24 @@ use hashbrown::{HashMap, HashSet}; -use std::ops::Deref; +use internment::ArcIntern; use crate::{ - DbIndex, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, + DbIndex, GenericTpl, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, check_type_compact, db_index::{LuaObjectType, LuaTupleType, LuaType}, semantic::{member::find_members_with_key, type_check::check_type_compact_with_level}, }; use super::{ - get_default_constructor, instantiate_type_generic, instantiate_type_generic_with_context, + TplCandidateSource, finalize_inferred_tpl_candidate, get_default_constructor, + instantiate_type_generic_inner, +}; +use crate::semantic::generic::type_substitutor::{ + GenericInstantiateContext, GenericInstantiateFrame, TplBinding, }; -use crate::semantic::generic::type_substitutor::{GenericInstantiateContext, TplBinding}; #[derive(Debug, Clone, Copy)] enum InferVariance { - // 协变 Covariant, - // 逆变 Contravariant, } @@ -38,27 +39,35 @@ struct InferCandidateSet { pub(super) fn instantiate_conditional( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, ) -> LuaType { - if let Some(distributed) = instantiate_distributed_conditional(context, conditional) { + let Some(frame) = frame.enter() else { + return instantiate_conditional_residual(context, frame, conditional, None, None); + }; + + if let Some(distributed) = instantiate_distributed_conditional(context, frame, conditional) { return distributed; } - instantiate_conditional_once(context, conditional) + instantiate_conditional_once(context, frame, conditional) } fn instantiate_conditional_once( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, ) -> LuaType { let left_type = instantiate_conditional_operand( context, + frame, conditional.get_checked_type(), true, conditional.has_new, ); let right_type = instantiate_conditional_operand( context, + frame, conditional.get_extends_type(), false, conditional.has_new, @@ -78,115 +87,165 @@ fn instantiate_conditional_once( ) { instantiate_true_branch( context, + frame, conditional, - finalize_infer_assignments(infer_assignments), + finalize_infer_assignments(context, conditional, infer_assignments), ) - } else { - instantiate_type_generic( - context.db, - conditional.get_false_type(), - context.substitutor, + } else if is_deferred_conditional_operand(&left_type) + || right_type.any_type(|inner| match inner { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + !tpl.get_tpl_id().is_conditional_infer() + } + LuaType::StrTplRef(_) + | LuaType::SelfInfer + | LuaType::Conditional(_) + | LuaType::Mapped(_) + | LuaType::Call(_) => true, + _ => false, + }) + { + instantiate_conditional_residual( + context, + frame, + conditional, + Some(left_type), + Some(right_type), ) + } else { + instantiate_type_generic_inner(context, frame, conditional.get_false_type()) }; } match check_conditional_extends(context.db, &left_type, &right_type) { - ConditionalCheck::True => instantiate_true_branch(context, conditional, HashMap::new()), - ConditionalCheck::False => instantiate_type_generic( - context.db, - conditional.get_false_type(), - context.substitutor, - ), + ConditionalCheck::True => { + instantiate_true_branch(context, frame, conditional, HashMap::new()) + } + ConditionalCheck::False => { + instantiate_type_generic_inner(context, frame, conditional.get_false_type()) + } ConditionalCheck::Both => { - let true_type = instantiate_true_branch(context, conditional, HashMap::new()); - let false_type = instantiate_type_generic( - context.db, - conditional.get_false_type(), - context.substitutor, - ); + if is_deferred_conditional_operand(&left_type) + || is_deferred_conditional_operand(&right_type) + { + return instantiate_conditional_residual( + context, + frame, + conditional, + Some(left_type), + Some(right_type), + ); + } + let true_type = instantiate_true_branch(context, frame, conditional, HashMap::new()); + let false_type = + instantiate_type_generic_inner(context, frame, conditional.get_false_type()); TypeOps::Union.apply(context.db, &true_type, &false_type) } } } +fn instantiate_conditional_residual( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + conditional: &LuaConditionalType, + checked_type: Option, + extends_type: Option, +) -> LuaType { + let instantiate_branch = |branch: &LuaType| { + if branch.any_type(|ty| match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + context.substitutor.get(tpl.get_tpl_id()).is_some() + } + LuaType::SelfInfer => context.substitutor.get_self_type().is_some(), + _ => false, + }) { + instantiate_type_generic_inner(context, frame, branch) + } else { + branch.clone() + } + }; + + LuaType::Conditional( + LuaConditionalType::new( + checked_type.unwrap_or_else(|| { + instantiate_type_generic_inner(context, frame, conditional.get_checked_type()) + }), + extends_type.unwrap_or_else(|| { + instantiate_type_generic_inner(context, frame, conditional.get_extends_type()) + }), + instantiate_branch(conditional.get_true_type()), + instantiate_branch(conditional.get_false_type()), + conditional.get_infer_params().to_vec(), + conditional.has_new, + ) + .into(), + ) +} + /// 处理分布式条件类型, 与`TS`中的分布式条件类型处理方式相同, 只有裸模版参数才会被分布式. fn instantiate_distributed_conditional( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, ) -> Option { - let tpl_id = naked_checked_type_tpl_id(conditional.get_checked_type())?; + let tpl_id = match conditional.get_checked_type() { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) if tpl.get_tpl_id().is_type() => { + tpl.get_tpl_id() + } + _ => return None, + }; let raw_checked_type = context.substitutor.get_raw_type(tpl_id)?; if raw_checked_type.is_never() { return Some(LuaType::Never); } - let members = union_members(raw_checked_type)?; + let members = match &raw_checked_type { + LuaType::Union(union) => union.into_vec(), + LuaType::MultiLineUnion(multi) => multi + .get_unions() + .iter() + .map(|(member, _)| member.clone()) + .collect(), + _ => return None, + }; let mut result = LuaType::Never; for member in members { let mut member_substitutor = context.substitutor.clone(); member_substitutor.bind(tpl_id, TplBinding::ReplaceConstType(member)); let member_context = context.with_substitutor(&member_substitutor); - let member_result = instantiate_conditional_once(&member_context, conditional); + let member_result = instantiate_conditional_once(&member_context, frame, conditional); result = TypeOps::Union.apply(context.db, &result, &member_result); } Some(result) } -fn naked_checked_type_tpl_id(checked_type: &LuaType) -> Option { - match checked_type { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) if tpl.get_tpl_id().is_type() => { - Some(tpl.get_tpl_id()) - } - _ => None, - } -} - -fn union_members(ty: &LuaType) -> Option> { - match ty { - LuaType::Union(union) => Some(union.into_vec()), - LuaType::MultiLineUnion(multi) => Some( - multi - .get_unions() - .iter() - .map(|(member, _)| member.clone()) - .collect(), - ), - _ => None, - } -} - fn instantiate_true_branch( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, infer_assignments: HashMap, ) -> LuaType { if infer_assignments.is_empty() { - return instantiate_type_generic( - context.db, - conditional.get_true_type(), - context.substitutor, - ); + return instantiate_type_generic_inner(context, frame, conditional.get_true_type()); } let mut true_substitutor = context.substitutor.clone(); for (tpl_id, ty) in infer_assignments { true_substitutor.bind(tpl_id, TplBinding::ConditionalInferType(ty)); } - instantiate_type_generic(context.db, conditional.get_true_type(), &true_substitutor) + let true_context = context.with_substitutor(&true_substitutor); + instantiate_type_generic_inner(&true_context, frame, conditional.get_true_type()) } fn contains_conditional_infer(ty: &LuaType) -> bool { - ty.any_type(conditional_infer_tpl_id) -} - -fn conditional_infer_tpl_id(ty: &LuaType) -> bool { - matches!( - ty, - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) - if tpl.get_tpl_id().is_conditional_infer() - ) + ty.any_type(|inner| { + matches!( + inner, + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if tpl.get_tpl_id().is_conditional_infer() + ) + }) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -217,6 +276,10 @@ fn check_conditional_extends(db: &DbIndex, source: &LuaType, target: &LuaType) - return ConditionalCheck::True; } + if literal_extends_base_type(source, target) { + return ConditionalCheck::True; + } + if let LuaType::Union(union) = source { let mut result = ConditionalCheck::False; for member in union.into_vec() { @@ -241,6 +304,10 @@ fn check_conditional_extends(db: &DbIndex, source: &LuaType, target: &LuaType) - return ConditionalCheck::False; } + if is_deferred_conditional_operand(source) || is_deferred_conditional_operand(target) { + return ConditionalCheck::Both; + } + if check_type_compact_with_level( db, source, @@ -263,6 +330,25 @@ fn merge_conditional_check(left: ConditionalCheck, right: ConditionalCheck) -> C } } +fn literal_extends_base_type(source: &LuaType, target: &LuaType) -> bool { + matches!( + (source, target), + ( + LuaType::StringConst(_) | LuaType::DocStringConst(_), + LuaType::String + ) | ( + LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_), + LuaType::Integer + ) | ( + LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_) | LuaType::FloatConst(_), + LuaType::Number, + ) | ( + LuaType::BooleanConst(_) | LuaType::DocBooleanConst(_), + LuaType::Boolean + ) + ) +} + fn collect_infer_assignments( db: &DbIndex, source: &LuaType, @@ -645,6 +731,8 @@ fn insert_infer_assignment( } fn finalize_infer_assignments( + context: &GenericInstantiateContext, + conditional: &LuaConditionalType, assignments: HashMap, ) -> HashMap { assignments @@ -653,29 +741,51 @@ fn finalize_infer_assignments( candidates .covariant .or(candidates.contravariant) - .map(|ty| (tpl_id, ty)) + .map(|raw_candidate| { + let Some(param) = conditional.get_infer_params().get(tpl_id.get_idx()) else { + return (tpl_id, raw_candidate); + }; + + let tpl = GenericTpl::new( + tpl_id, + ArcIntern::new(param.name.clone()), + param.type_constraint.clone(), + param.default_type.clone(), + ); + ( + tpl_id, + finalize_inferred_tpl_candidate( + context.db, + &tpl, + &raw_candidate, + TplCandidateSource::ConstPreserving, + true, + true, + context.substitutor, + ), + ) + }) }) .collect() } fn instantiate_conditional_operand( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, operand: &LuaType, checked: bool, has_new: bool, ) -> LuaType { - let mut result = instantiate_type_generic_with_context(context, operand); + let mut result = instantiate_type_generic_inner(context, frame, operand); if let LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) = operand { let tpl_id = tpl_ref.get_tpl_id(); if let Some(raw) = context.substitutor.get_raw_type(tpl_id) { result = raw.clone(); - } else if checked && result.contains_tpl_node() { - result = LuaType::Unknown; + } else if checked && result.is_never() { + result = LuaType::Never; } } - result = actualize_unresolved_templates(result); - if has_new && let LuaType::Ref(id) | LuaType::Def(id) = &result && let Some(decl) = context.db.get_type_index().get_type_decl(id) @@ -688,147 +798,17 @@ fn instantiate_conditional_operand( result } -// 条件类型判定只消费已经实例化后的实际类型, 残留的普通模板引用在这里递归收敛为 `unknown`. -// `infer` pattern 也以模板引用表示, 必须保留下来供后续结构匹配绑定. -fn actualize_unresolved_templates(ty: LuaType) -> LuaType { - match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { - if tpl.get_tpl_id().is_conditional_infer() { - // Conditional infer 是右侧 pattern 的占位孔, 不能像普通未解模板一样抹成 unknown. - LuaType::TplRef(tpl) - } else { - LuaType::Unknown - } - } - LuaType::StrTplRef(_) => LuaType::Unknown, - LuaType::Array(array) => LuaType::Array( - crate::LuaArrayType::new( - actualize_unresolved_templates(array.get_base().clone()), - array.get_len().clone(), - ) - .into(), - ), - LuaType::Tuple(tuple) => LuaType::Tuple( - LuaTupleType::new( - tuple - .get_types() - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - tuple.status, - ) - .into(), - ), - LuaType::DocFunction(func) => LuaType::DocFunction( - crate::LuaFunctionType::new( - func.get_async_state(), - func.is_colon_define(), - func.is_variadic(), - func.get_params() - .iter() - .map(|(name, ty)| { - (name.clone(), ty.clone().map(actualize_unresolved_templates)) - }) - .collect(), - actualize_unresolved_templates(func.get_ret().clone()), - ) - .into(), - ), - LuaType::Object(object) => LuaType::Object( - LuaObjectType::new_with_fields( - object - .get_fields() - .iter() - .map(|(key, ty)| (key.clone(), actualize_unresolved_templates(ty.clone()))) - .collect(), - object - .get_index_access() - .iter() - .map(|(key, value)| { - ( - actualize_unresolved_templates(key.clone()), - actualize_unresolved_templates(value.clone()), - ) - }) - .collect(), - ) - .into(), - ), - LuaType::Union(union) => LuaType::from_vec( - union - .into_vec() - .into_iter() - .map(actualize_unresolved_templates) - .collect(), - ), - LuaType::MultiLineUnion(multi) => LuaType::from_vec( - multi - .get_unions() - .iter() - .map(|(ty, _)| actualize_unresolved_templates(ty.clone())) - .collect(), - ), - LuaType::Intersection(intersection) => LuaType::Intersection( - crate::LuaIntersectionType::new( - intersection - .get_types() - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - ) - .into(), - ), - LuaType::Generic(generic) => LuaType::Generic( - crate::LuaGenericType::new( - generic.get_base_type_id(), - generic - .get_params() - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - ) - .into(), - ), - LuaType::TableGeneric(params) => LuaType::TableGeneric( - params - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect::>() - .into(), - ), - LuaType::Variadic(variadic) => LuaType::Variadic( - match variadic.deref() { - crate::VariadicType::Base(base) => { - crate::VariadicType::Base(actualize_unresolved_templates(base.clone())) - } - crate::VariadicType::Multi(types) => crate::VariadicType::Multi( - types - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - ), - } - .into(), - ), - LuaType::TypeGuard(guard) => { - LuaType::TypeGuard(actualize_unresolved_templates(guard.deref().clone()).into()) - } - LuaType::Conditional(conditional) => LuaType::Conditional( - LuaConditionalType::new( - actualize_unresolved_templates(conditional.get_checked_type().clone()), - actualize_unresolved_templates(conditional.get_extends_type().clone()), - actualize_unresolved_templates(conditional.get_true_type().clone()), - actualize_unresolved_templates(conditional.get_false_type().clone()), - conditional.get_infer_params().to_vec(), - conditional.has_new, - ) - .into(), - ), - ty => ty, - } +fn is_deferred_conditional_operand(ty: &LuaType) -> bool { + ty.any_type(|inner| { + matches!( + inner, + LuaType::TplRef(_) + | LuaType::ConstTplRef(_) + | LuaType::StrTplRef(_) + | LuaType::SelfInfer + | LuaType::Conditional(_) + | LuaType::Mapped(_) + | LuaType::Call(_) + ) + }) } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs new file mode 100644 index 000000000..21a9ad912 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs @@ -0,0 +1,248 @@ +use hashbrown::{HashMap, HashSet}; +use std::ops::Deref; + +use crate::{ + GenericParam, LuaAliasCallKind, LuaMappedType, LuaMemberKey, LuaObjectType, LuaTupleStatus, + LuaTupleType, LuaType, TypeOps, VariadicType, +}; + +use super::{ + GenericInstantiateContext, GenericInstantiateFrame, instantiate_special_generic, + instantiate_type_generic_inner, key_type_to_member_key, +}; +use crate::semantic::generic::type_substitutor::TplBinding; + +pub(super) fn instantiate_mapped_type( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + mapped: &LuaMappedType, +) -> LuaType { + let Some(frame) = frame.enter() else { + return instantiate_mapped_residual(context, frame, mapped); + }; + + let Some(constraint) = mapped.param.1.type_constraint.as_ref() else { + return instantiate_mapped_residual(context, frame, mapped); + }; + + let Some(key_domain) = resolve_mapped_key_domain(context, frame, constraint) else { + return instantiate_mapped_residual(context, frame, mapped); + }; + + let empty_object = + || LuaType::Object(LuaObjectType::new_with_fields(HashMap::new(), Vec::new()).into()); + + if key_domain.keys.is_empty() { + return empty_object(); + } + + let key_count = key_domain.keys.len(); + let mut visited = HashSet::with_capacity(key_count); + let mut field_indices: HashMap = HashMap::with_capacity(key_count); + let mut fields: Vec<(LuaMemberKey, LuaType)> = Vec::with_capacity(key_count); + let mut index_access: Vec<(LuaType, LuaType)> = Vec::with_capacity(key_count); + let mut local_substitutor = context.substitutor.clone(); + + for key_ty in key_domain.keys { + if !visited.insert(key_ty.clone()) { + continue; + } + + local_substitutor.bind(mapped.param.0, TplBinding::ReplaceConstType(key_ty.clone())); + let local_context = context.with_substitutor(&local_substitutor); + let mut value_ty = instantiate_type_generic_inner(&local_context, frame, &mapped.value); + if mapped.is_optional { + value_ty = TypeOps::Union.apply(context.db, &value_ty, &LuaType::Nil); + } + + if let Some(member_key) = key_type_to_member_key(&key_ty) { + if let Some(index) = field_indices.get(&member_key).copied() { + let (_, existing) = &mut fields[index]; + let merged = LuaType::from_vec(vec![existing.clone(), value_ty]); + *existing = merged; + } else { + field_indices.insert(member_key.clone(), fields.len()); + fields.push((member_key, value_ty)); + } + } else { + index_access.push((key_ty, value_ty)); + } + } + + if fields.is_empty() && index_access.is_empty() { + return empty_object(); + } + + if key_domain.tuple_like + && index_access.is_empty() + && let Some(types) = mapped_tuple_types(&fields) + { + return LuaType::Tuple(LuaTupleType::new(types, LuaTupleStatus::InferResolve).into()); + } + + let field_map: HashMap = fields.into_iter().collect(); + LuaType::Object(LuaObjectType::new_with_fields(field_map, index_access).into()) +} + +struct MappedKeyDomain { + keys: Vec, + tuple_like: bool, +} + +fn resolve_mapped_key_domain( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + constraint: &LuaType, +) -> Option { + if let LuaType::Call(alias_call) = constraint + && alias_call.get_call_kind() == LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 + { + let source = instantiate_type_generic_inner(context, frame, &alias_call.get_operands()[0]); + let keys = instantiate_special_generic::get_keyof_type(context.db, &source)?; + let mut atoms = Vec::new(); + if !collect_mapped_key_atoms(&keys, &mut atoms) { + return None; + } + return Some(MappedKeyDomain { + keys: atoms, + tuple_like: source.is_tuple() || matches!(source, LuaType::Variadic(_)), + }); + } + + let instantiated = instantiate_type_generic_inner(context, frame, constraint); + match &instantiated { + LuaType::Call(alias_call) + if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 => + { + let source = &alias_call.get_operands()[0]; + let keys = instantiate_special_generic::get_keyof_type(context.db, source)?; + let mut atoms = Vec::new(); + if !collect_mapped_key_atoms(&keys, &mut atoms) { + return None; + } + Some(MappedKeyDomain { + keys: atoms, + tuple_like: source.is_tuple() || matches!(source, LuaType::Variadic(_)), + }) + } + _ => { + let mut atoms = Vec::new(); + if !collect_mapped_key_atoms(&instantiated, &mut atoms) { + return None; + } + Some(MappedKeyDomain { + tuple_like: instantiated.is_tuple(), + keys: atoms, + }) + } + } +} + +fn instantiate_mapped_residual( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + mapped: &LuaMappedType, +) -> LuaType { + let param = ( + mapped.param.0, + GenericParam::new( + mapped.param.1.name.clone(), + mapped + .param + .1 + .type_constraint + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, frame, ty)), + mapped + .param + .1 + .default_type + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, frame, ty)), + mapped.param.1.attributes.clone(), + ), + ); + + LuaType::Mapped( + LuaMappedType::new( + param, + instantiate_type_generic_inner(context, frame, &mapped.value), + mapped.is_readonly, + mapped.is_optional, + ) + .into(), + ) +} + +fn mapped_tuple_types(fields: &[(LuaMemberKey, LuaType)]) -> Option> { + let mut indexed = fields + .iter() + .filter_map(|(key, ty)| match key { + LuaMemberKey::Integer(i) => Some((*i, ty.clone())), + _ => None, + }) + .collect::>(); + + if indexed.len() != fields.len() { + return None; + } + + indexed.sort_by_key(|(index, _)| *index); + let starts_at_zero = indexed.first().is_some_and(|(index, _)| *index == 0); + let expected_start = if starts_at_zero { 0 } else { 1 }; + for (offset, (index, _)) in indexed.iter().enumerate() { + if *index != expected_start + offset as i64 { + return None; + } + } + + Some(indexed.into_iter().map(|(_, ty)| ty).collect()) +} + +fn collect_mapped_key_atoms(key_ty: &LuaType, acc: &mut Vec) -> bool { + match key_ty { + LuaType::Union(union) => { + for member in union.into_vec() { + if !collect_mapped_key_atoms(&member, acc) { + return false; + } + } + true + } + LuaType::MultiLineUnion(multi) => { + for (member, _) in multi.get_unions() { + if !collect_mapped_key_atoms(member, acc) { + return false; + } + } + true + } + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => collect_mapped_key_atoms(base, acc), + VariadicType::Multi(types) => { + for member in types { + if !collect_mapped_key_atoms(member, acc) { + return false; + } + } + true + } + }, + LuaType::Tuple(tuple) => { + for member in tuple.get_types() { + if !collect_mapped_key_atoms(member, acc) { + return false; + } + } + true + } + LuaType::Never => true, + LuaType::Unknown | LuaType::Call(_) | LuaType::Mapped(_) => false, + _ => { + acc.push(key_ty.clone()); + true + } + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index c44efafa5..80f10166f 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -1,6 +1,6 @@ use crate::{ DbIndex, LuaAliasCallKind, LuaAliasCallType, LuaMemberInfo, LuaMemberKey, LuaObjectType, - LuaTupleStatus, LuaTupleType, LuaType, LuaTypeNode, TypeOps, VariadicType, get_member_map, + LuaType, LuaTypeNode, TypeOps, VariadicType, get_member_map, semantic::{ generic::key_type_to_member_key, member::{find_members, infer_raw_member_type}, @@ -11,18 +11,19 @@ use hashbrown::HashMap; use std::{ops::Deref, vec}; use super::{ - GenericInstantiateContext, SubstitutorValue, TypeSubstitutor, - instantiate_type_generic_with_context, + GenericInstantiateContext, GenericInstantiateFrame, SubstitutorValue, TypeSubstitutor, + instantiate_type_generic_inner, }; pub(super) fn instantiate_alias_call( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, alias_call: &LuaAliasCallType, ) -> LuaType { let operand_exprs = alias_call.get_operands(); let operands = operand_exprs .iter() - .map(|it| instantiate_type_generic_with_context(context, it)) + .map(|it| instantiate_type_generic_inner(context, frame, it)) .collect::>(); match alias_call.get_call_kind() { @@ -45,16 +46,12 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } - let members = get_keyof_members(context.db, &operands[0]).unwrap_or_default(); - let member_key_types = members - .iter() - .filter_map(|m| match &m.key { - LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)), - LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())), - _ => None, - }) - .collect::>(); - LuaType::Tuple(LuaTupleType::new(member_key_types, LuaTupleStatus::InferResolve).into()) + match get_keyof_type(context.db, &operands[0]) { + Some(key_type) => key_type, + None => { + LuaType::Call(LuaAliasCallType::new(LuaAliasCallKind::KeyOf, operands).into()) + } + } } // 条件类型不在此处理 LuaAliasCallKind::Extends => { @@ -80,7 +77,7 @@ pub(super) fn instantiate_alias_call( instantiate_select_call(&operands[0], &operands[1]) } LuaAliasCallKind::Unpack => { - let operands = resolve_unpack_operands(context, operand_exprs); + let operands = resolve_unpack_operands(context, frame, operand_exprs); instantiate_unpack_call(context.db, &operands) } LuaAliasCallKind::RawGet => { @@ -107,6 +104,29 @@ pub(super) fn instantiate_alias_call( } } +pub(super) fn get_keyof_type(db: &DbIndex, ty: &LuaType) -> Option { + let members = get_keyof_members(db, ty)?; + let member_key_types = members + .iter() + .filter_map(|m| match &m.key { + LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)), + LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())), + LuaMemberKey::ExprType(typ) => Some(typ.clone()), + _ => None, + }) + .collect::>(); + + if member_key_types.is_empty() { + if members.is_empty() { + return Some(LuaType::Never); + } + + return None; + } + + Some(LuaType::from_vec(member_key_types)) +} + fn instantiate_merge_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { if operands.len() != 2 { return LuaType::Unknown; @@ -223,6 +243,7 @@ fn instantiate_select_call(source: &LuaType, index: &LuaType) -> LuaType { fn resolve_unpack_operands( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, operand_exprs: &[LuaType], ) -> Vec { operand_exprs @@ -230,7 +251,7 @@ fn resolve_unpack_operands( .enumerate() .map(|(index, operand)| { if index != 0 { - return instantiate_type_generic_with_context(context, operand); + return instantiate_type_generic_inner(context, frame, operand); } let raw = match operand { LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => context @@ -238,9 +259,12 @@ fn resolve_unpack_operands( .get(tpl_ref.get_tpl_id()) .and_then(|value| match value { SubstitutorValue::None => None, - SubstitutorValue::Type(ty) => Some(ty.raw().clone()), - SubstitutorValue::MultiTypes { raw_types, .. } => Some(LuaType::Variadic( - VariadicType::Multi(raw_types.clone()).into(), + SubstitutorValue::Type { value, .. } => Some(value.raw().clone()), + SubstitutorValue::MultiTypes { values, .. } => Some(LuaType::Variadic( + VariadicType::Multi( + values.iter().map(|value| value.raw().clone()).collect(), + ) + .into(), )), SubstitutorValue::Params(params) => Some( params @@ -254,7 +278,7 @@ fn resolve_unpack_operands( }), _ => None, }; - raw.unwrap_or_else(|| instantiate_type_generic_with_context(context, operand)) + raw.unwrap_or_else(|| instantiate_type_generic_inner(context, frame, operand)) }) .collect() } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index 571244289..94612aaee 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -2,15 +2,15 @@ mod complete_generic_args; mod infer_call_func_generic; mod inference_widening; mod instantiate_conditional_generic; +mod instantiate_mapped_type; mod instantiate_special_generic; -use hashbrown::{HashMap, HashSet}; +use hashbrown::HashMap; use std::ops::Deref; use crate::{ - DbIndex, GenericTpl, GenericTplId, LuaArrayType, LuaMappedType, LuaMemberKey, - LuaOperatorMetaMethod, LuaSignatureId, LuaTupleStatus, LuaTupleType, LuaTypeDeclId, - LuaTypeNode, TypeOps, + DbIndex, GenericTpl, LuaArrayType, LuaMemberKey, LuaOperatorMetaMethod, LuaSignatureId, + LuaTupleType, LuaTypeDeclId, LuaTypeNode, db_index::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaType, LuaUnionType, VariadicType, @@ -18,7 +18,8 @@ use crate::{ }; use super::type_substitutor::{ - GenericInstantiateContext, SubstitutorValue, TypeSubstitutor, UninferredTplPolicy, + GenericInstantiateContext, GenericInstantiateFrame, SubstitutorValue, TypeSubstitutor, + UninferredTplPolicy, }; pub use complete_generic_args::{ GenericArgumentCompletion, complete_type_generic_args, complete_type_generic_args_in_type, @@ -28,6 +29,7 @@ pub(in crate::semantic::generic) use inference_widening::{ TplCandidateSource, finalize_inferred_tpl_candidate, }; pub use inference_widening::{WideningContext, widen_type_with_context}; +use instantiate_mapped_type::instantiate_mapped_type as instantiate_mapped_type_inner; pub use instantiate_special_generic::get_keyof_members; pub fn instantiate_type_generic( @@ -36,35 +38,46 @@ pub fn instantiate_type_generic( substitutor: &TypeSubstitutor, ) -> LuaType { let context = GenericInstantiateContext::new(db, substitutor); + let frame = context.root_frame(); match ty { - LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context(&context, doc_func), - _ => instantiate_type_generic_with_context(&context, ty), + LuaType::DocFunction(doc_func) => instantiate_doc_function(&context, frame, doc_func), + _ => instantiate_type_generic_inner(&context, frame, ty), } } -pub(super) fn instantiate_type_generic_with_context( +pub(super) fn instantiate_type_generic_inner( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, ty: &LuaType, ) -> LuaType { + let Some(frame) = frame.enter() else { + return ty.clone(); + }; + match ty { - LuaType::Array(array_type) => instantiate_array(context, array_type.get_base()), - LuaType::Tuple(tuple) => instantiate_tuple(context, tuple), - LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context( - &context.with_policy(UninferredTplPolicy::PreserveTplRef), + LuaType::Array(array_type) => instantiate_array(context, frame, array_type.get_base()), + LuaType::Tuple(tuple) => instantiate_tuple(context, frame, tuple), + LuaType::DocFunction(doc_func) => instantiate_doc_function( + context, + frame.with_policy(UninferredTplPolicy::PreserveTplRef), doc_func, ), - LuaType::Object(object) => instantiate_object(context, object), - LuaType::Union(union) => instantiate_union(context, union), - LuaType::Intersection(intersection) => instantiate_intersection(context, intersection), - LuaType::Generic(generic) => instantiate_generic_with_context(context, generic), - LuaType::TableGeneric(table_params) => instantiate_table_generic(context, table_params), - LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context), - LuaType::ConstTplRef(tpl) => instantiate_const_tpl_ref(tpl, context), - LuaType::Signature(sig_id) => instantiate_signature(context, sig_id), + LuaType::Object(object) => instantiate_object(context, frame, object), + LuaType::Union(union) => instantiate_union(context, frame, union), + LuaType::Intersection(intersection) => { + instantiate_intersection(context, frame, intersection) + } + LuaType::Generic(generic) => instantiate_generic(context, frame, generic), + LuaType::TableGeneric(table_params) => { + instantiate_table_generic(context, frame, table_params) + } + LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context, frame), + LuaType::ConstTplRef(tpl) => instantiate_const_tpl_ref(tpl, context, frame), + LuaType::Signature(sig_id) => instantiate_signature(context, frame, sig_id), LuaType::Call(alias_call) => { - instantiate_special_generic::instantiate_alias_call(context, alias_call) + instantiate_special_generic::instantiate_alias_call(context, frame, alias_call) } - LuaType::Variadic(variadic) => instantiate_variadic_type(context, variadic), + LuaType::Variadic(variadic) => instantiate_variadic_type(context, frame, variadic), LuaType::SelfInfer => { if let Some(typ) = context.substitutor.get_self_type() { typ.clone() @@ -73,29 +86,34 @@ pub(super) fn instantiate_type_generic_with_context( } } LuaType::TypeGuard(guard) => { - let inner = instantiate_type_generic_with_context(context, guard.deref()); + let inner = instantiate_type_generic_inner(context, frame, guard.deref()); LuaType::TypeGuard(inner.into()) } LuaType::Conditional(conditional) => { - instantiate_conditional_generic::instantiate_conditional(context, conditional) + instantiate_conditional_generic::instantiate_conditional(context, frame, conditional) } - LuaType::Mapped(mapped) => instantiate_mapped_type(context, mapped.deref()), + LuaType::Mapped(mapped) => instantiate_mapped_type_inner(context, frame, mapped.deref()), _ => ty.clone(), } } -fn instantiate_types<'a, I>(context: &GenericInstantiateContext, types: I) -> Vec +fn instantiate_types<'a, I>( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + types: I, +) -> Vec where I: IntoIterator, { types .into_iter() - .map(|ty| instantiate_type_generic_with_context(context, ty)) + .map(|ty| instantiate_type_generic_inner(context, frame, ty)) .collect() } fn instantiate_type_pairs<'a, I>( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, pairs: I, ) -> Vec<(LuaType, LuaType)> where @@ -105,19 +123,27 @@ where .into_iter() .map(|(key, value)| { ( - instantiate_type_generic_with_context(context, key), - instantiate_type_generic_with_context(context, value), + instantiate_type_generic_inner(context, frame, key), + instantiate_type_generic_inner(context, frame, value), ) }) .collect() } -fn instantiate_array(context: &GenericInstantiateContext, base: &LuaType) -> LuaType { - let base = instantiate_type_generic_with_context(context, base); +fn instantiate_array( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + base: &LuaType, +) -> LuaType { + let base = instantiate_type_generic_inner(context, frame, base); LuaType::Array(LuaArrayType::from_base_type(base).into()) } -fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) -> LuaType { +fn instantiate_tuple( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + tuple: &LuaTupleType, +) -> LuaType { let mut new_types = Vec::new(); for t in tuple.get_types() { if let LuaType::Variadic(inner) = t { @@ -127,18 +153,20 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => new_types - .push(instantiate_uninferred_tpl_fallback(tpl, context)), - SubstitutorValue::MultiTypes { types, .. } => { - for typ in types { - new_types.push(typ.clone()); - } - } + .push(instantiate_uninferred_tpl_fallback(tpl, context, frame)), SubstitutorValue::Params(params) => { for (_, ty) in params { new_types.push(ty.clone().unwrap_or(LuaType::Unknown)); } } - SubstitutorValue::Type(ty) => new_types.push(ty.resolved().clone()), + SubstitutorValue::MultiTypes { values, .. } => { + new_types.extend( + values.iter().map(|value| value.resolved().clone()), + ); + } + SubstitutorValue::Type { value, .. } => { + new_types.push(value.resolved().clone()) + } SubstitutorValue::MultiBase(base) => new_types.push(base.clone()), } } else { @@ -152,14 +180,15 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) break; } - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, frame, t); new_types.push(t); } LuaType::Tuple(LuaTupleType::new(new_types, tuple.status).into()) } -fn instantiate_doc_function_with_context( +fn instantiate_doc_function( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, doc_func: &LuaFunctionType, ) -> LuaType { let tpl_func_params = doc_func.get_params(); @@ -178,15 +207,16 @@ fn instantiate_doc_function_with_context( match origin_param_type { LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Base(base) => match base { - LuaType::TplRef(tpl) => { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { - let ty = instantiate_uninferred_tpl_fallback(tpl, context); + let ty = + instantiate_uninferred_tpl_fallback(tpl, context, frame); new_params.push((origin_param.0.clone(), Some(ty))); } - SubstitutorValue::Type(ty) => { - let resolved_type = ty.resolved().clone(); + SubstitutorValue::Type { value, .. } => { + let resolved_type = value.resolved().clone(); // 如果参数是 `...: T...` if origin_param.0 == "..." { // 类型是 tuple, 那么我们将展开 tuple @@ -217,10 +247,11 @@ fn instantiate_doc_function_with_context( new_params.push(param.clone()); } } - SubstitutorValue::MultiTypes { types, .. } => { - for (i, typ) in types.iter().enumerate() { + SubstitutorValue::MultiTypes { values, .. } => { + for (i, value) in values.iter().enumerate() { let param_name = format!("var{}", i); - new_params.push((param_name, Some(typ.clone()))); + new_params + .push((param_name, Some(value.resolved().clone()))); } } _ => { @@ -238,7 +269,7 @@ fn instantiate_doc_function_with_context( } } LuaType::Generic(generic) => { - let new_type = instantiate_generic_with_context(context, generic); + let new_type = instantiate_generic(context, frame, generic); // 如果是 rest 参数且实例化后的类型是 tuple, 那么我们将展开 tuple if let LuaType::Tuple(tuple_type) = &new_type { let base_index = new_params.len(); @@ -256,13 +287,13 @@ fn instantiate_doc_function_with_context( VariadicType::Multi(_) => (), }, _ => { - let new_type = instantiate_type_generic_with_context(context, origin_param_type); + let new_type = instantiate_type_generic_inner(context, frame, origin_param_type); new_params.push((origin_param.0.clone(), Some(new_type))); } } } - let mut inst_ret_type = instantiate_type_generic_with_context(context, tpl_ret); + let mut inst_ret_type = instantiate_type_generic_inner(context, frame, tpl_ret); // 对于可变返回值, 如果实例化是 tuple, 那么我们将展开 tuple if let LuaType::Variadic(_) = &&tpl_ret && let LuaType::Tuple(tuple) = &inst_ret_type @@ -300,43 +331,57 @@ fn instantiate_doc_function_with_context( ) } -fn instantiate_object(context: &GenericInstantiateContext, object: &LuaObjectType) -> LuaType { +fn instantiate_object( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + object: &LuaObjectType, +) -> LuaType { let new_fields = object .get_fields() .iter() .map(|(key, field)| { ( key.clone(), - instantiate_type_generic_with_context(context, field), + instantiate_type_generic_inner(context, frame, field), ) }) .collect::>(); - let new_index_access = instantiate_type_pairs(context, object.get_index_access().iter()); + let new_index_access = instantiate_type_pairs(context, frame, object.get_index_access().iter()); LuaType::Object(LuaObjectType::new_with_fields(new_fields, new_index_access).into()) } -fn instantiate_union(context: &GenericInstantiateContext, union: &LuaUnionType) -> LuaType { - LuaType::from_vec(instantiate_types(context, union.into_vec().iter())) +fn instantiate_union( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + union: &LuaUnionType, +) -> LuaType { + LuaType::from_vec(instantiate_types(context, frame, union.into_vec().iter())) } fn instantiate_intersection( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, intersection: &LuaIntersectionType, ) -> LuaType { LuaType::Intersection( - LuaIntersectionType::new(instantiate_types(context, intersection.get_types().iter())) - .into(), + LuaIntersectionType::new(instantiate_types( + context, + frame, + intersection.get_types().iter(), + )) + .into(), ) } -fn instantiate_generic_with_context( +fn instantiate_generic( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, generic: &LuaGenericType, ) -> LuaType { let generic_params = generic.get_params(); - let new_params = instantiate_types(context, generic_params.iter()); + let new_params = instantiate_types(context, frame, generic_params.iter()); let base = generic.get_base_type(); let type_decl_id = if let LuaType::Ref(id) = base { @@ -349,9 +394,14 @@ fn instantiate_generic_with_context( && let Some(type_decl) = context.db.get_type_index().get_type_decl(&type_decl_id) && type_decl.is_alias() { - let new_substitutor = TypeSubstitutor::from_alias(new_params.clone(), type_decl_id.clone()); - if let Some(origin) = type_decl.get_alias_origin(context.db, Some(&new_substitutor)) { - return origin; + let Some(alias_context) = context.enter_alias(&type_decl_id) else { + return LuaType::Generic(LuaGenericType::new(type_decl_id, new_params).into()); + }; + let new_substitutor = + TypeSubstitutor::from_alias(context.db, new_params.clone(), type_decl_id.clone()); + let alias_context = alias_context.with_substitutor(&new_substitutor); + if let Some(origin) = type_decl.get_alias_ref() { + return instantiate_type_generic_inner(&alias_context, frame, origin); } } @@ -360,43 +410,57 @@ fn instantiate_generic_with_context( fn instantiate_table_generic( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, table_params: &[LuaType], ) -> LuaType { - LuaType::TableGeneric(instantiate_types(context, table_params.iter()).into()) + LuaType::TableGeneric(instantiate_types(context, frame, table_params.iter()).into()) } fn instantiate_uninferred_tpl_fallback( tpl: &GenericTpl, context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, ) -> LuaType { // 一些情况下需要保留 TplRef, 例如高阶函数调用 - if context.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { + if frame.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { return LuaType::TplRef(tpl.clone().into()); } // 显式默认值优先, 然后是 extends 约束, 最后才是 unknown. if let Some(default_type) = tpl.get_default_type() { - return instantiate_type_generic_with_context(context, default_type); + return instantiate_type_generic_inner(context, frame, default_type); } if let Some(constraint) = tpl.get_constraint() { - return instantiate_type_generic_with_context(context, constraint); + return instantiate_type_generic_inner(context, frame, constraint); } LuaType::Unknown } -fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType { +fn instantiate_tpl_ref( + tpl: &GenericTpl, + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, +) -> LuaType { if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { - return instantiate_uninferred_tpl_fallback(tpl, context); + return instantiate_uninferred_tpl_fallback(tpl, context, frame); } - SubstitutorValue::Type(ty) => { - return ty.resolved().clone(); + SubstitutorValue::Type { value, .. } => { + return value.resolved().clone(); } - SubstitutorValue::MultiTypes { types, .. } => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + SubstitutorValue::MultiTypes { values, .. } => { + return LuaType::Variadic( + VariadicType::Multi( + values + .iter() + .map(|value| value.resolved().clone()) + .collect(), + ) + .into(), + ); } SubstitutorValue::Params(params) => { return params @@ -413,17 +477,29 @@ fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType::TplRef(tpl.clone().into()) } -fn instantiate_const_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType { +fn instantiate_const_tpl_ref( + tpl: &GenericTpl, + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, +) -> LuaType { if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { - return instantiate_uninferred_tpl_fallback(tpl, context); + return instantiate_uninferred_tpl_fallback(tpl, context, frame); } - SubstitutorValue::Type(ty) => { - return ty.resolved().clone(); + SubstitutorValue::Type { value, .. } => { + return value.resolved().clone(); } - SubstitutorValue::MultiTypes { types, .. } => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + SubstitutorValue::MultiTypes { values, .. } => { + return LuaType::Variadic( + VariadicType::Multi( + values + .iter() + .map(|value| value.resolved().clone()) + .collect(), + ) + .into(), + ); } SubstitutorValue::Params(params) => { return params @@ -442,20 +518,22 @@ fn instantiate_const_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateConte fn instantiate_signature( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, signature_id: &LuaSignatureId, ) -> LuaType { if let Some(signature) = context.db.get_signature_index().get(signature_id) { let origin_type = { let fake_doc_function = signature.to_doc_func_type(); - instantiate_doc_function_with_context(context, &fake_doc_function) + instantiate_doc_function(context, frame, &fake_doc_function) }; if signature.overloads.is_empty() { return origin_type; } else { let mut result = Vec::new(); for overload in signature.overloads.iter() { - result.push(instantiate_doc_function_with_context( + result.push(instantiate_doc_function( context, + frame, &(*overload).clone(), )); } @@ -469,6 +547,7 @@ fn instantiate_signature( fn instantiate_variadic_type( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, variadic: &VariadicType, ) -> LuaType { match variadic { @@ -477,15 +556,15 @@ fn instantiate_variadic_type( if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { - let fallback = instantiate_uninferred_tpl_fallback(tpl, context); + let fallback = instantiate_uninferred_tpl_fallback(tpl, context, frame); return match fallback { LuaType::Variadic(_) | LuaType::Never => fallback, LuaType::Nil | LuaType::Any | LuaType::Unknown => fallback, _ => LuaType::Variadic(VariadicType::Base(fallback).into()), }; } - SubstitutorValue::Type(ty) => { - let resolved_type = ty.resolved().clone(); + SubstitutorValue::Type { value, .. } => { + let resolved_type = value.resolved().clone(); if matches!( resolved_type, LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never @@ -494,8 +573,16 @@ fn instantiate_variadic_type( } return LuaType::Variadic(VariadicType::Base(resolved_type).into()); } - SubstitutorValue::MultiTypes { types, .. } => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + SubstitutorValue::MultiTypes { values, .. } => { + return LuaType::Variadic( + VariadicType::Multi( + values + .iter() + .map(|value| value.resolved().clone()) + .collect(), + ) + .into(), + ); } SubstitutorValue::Params(params) => { let types = params @@ -513,7 +600,7 @@ fn instantiate_variadic_type( } } LuaType::Generic(generic) => { - return instantiate_generic_with_context(context, generic); + return instantiate_generic(context, frame, generic); } _ => {} }, @@ -521,7 +608,7 @@ fn instantiate_variadic_type( if types.iter().any(LuaTypeNode::contains_tpl_node) { let mut new_types = Vec::new(); for t in types { - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, frame, t); match t { LuaType::Never => {} LuaType::Variadic(variadic) => match variadic.deref() { @@ -543,92 +630,6 @@ fn instantiate_variadic_type( LuaType::Variadic(variadic.clone().into()) } -fn instantiate_mapped_type(context: &GenericInstantiateContext, mapped: &LuaMappedType) -> LuaType { - let constraint = mapped - .param - .1 - .type_constraint - .as_ref() - .map(|ty| instantiate_type_generic_with_context(context, ty)); - - if let Some(constraint) = constraint { - let mut key_types = Vec::new(); - collect_mapped_key_atoms(&constraint, &mut key_types); - - let mut visited = HashSet::new(); - let mut fields: Vec<(LuaMemberKey, LuaType)> = Vec::new(); - let mut index_access: Vec<(LuaType, LuaType)> = Vec::new(); - - for key_ty in key_types { - if !visited.insert(key_ty.clone()) { - continue; - } - - let value_ty = instantiate_mapped_value(context, mapped, mapped.param.0, &key_ty); - - if let Some(member_key) = key_type_to_member_key(&key_ty) { - if let Some((_, existing)) = fields.iter_mut().find(|(key, _)| key == &member_key) { - let merged = LuaType::from_vec(vec![existing.clone(), value_ty]); - *existing = merged; - } else { - fields.push((member_key, value_ty)); - } - } else { - index_access.push((key_ty, value_ty)); - } - } - - if !fields.is_empty() || !index_access.is_empty() { - // key 从 0 开始递增才被视为元组 - if constraint.is_tuple() { - let mut index = 0; - let mut is_tuple = true; - for (key, _) in &fields { - if let LuaMemberKey::Integer(i) = key { - if *i != index { - is_tuple = false; - break; - } - index += 1; - } else { - is_tuple = false; - break; - } - } - if is_tuple { - let types = fields.into_iter().map(|(_, ty)| ty).collect(); - return LuaType::Tuple( - LuaTupleType::new(types, LuaTupleStatus::InferResolve).into(), - ); - } - } - let field_map: HashMap = fields.into_iter().collect(); - return LuaType::Object(LuaObjectType::new_with_fields(field_map, index_access).into()); - } - } - - instantiate_type_generic_with_context(context, &mapped.value) -} - -fn instantiate_mapped_value( - context: &GenericInstantiateContext, - mapped: &LuaMappedType, - tpl_id: GenericTplId, - replacement: &LuaType, -) -> LuaType { - let mut local_substitutor = context.substitutor.clone(); - local_substitutor.bind_type(tpl_id, replacement.clone()); - let local_context = context.with_substitutor(&local_substitutor); - let mut result = instantiate_type_generic_with_context(&local_context, &mapped.value); - // 根据 readonly 和 optional 属性进行处理 - if mapped.is_optional { - result = TypeOps::Union.apply(context.db, &result, &LuaType::Nil); - } - // TODO: 处理 readonly, 但目前 readonly 的实现存在问题, 这里我们先跳过 - - result -} - pub(super) fn key_type_to_member_key(key_ty: &LuaType) -> Option { match key_ty { LuaType::DocStringConst(s) => Some(LuaMemberKey::Name(s.deref().clone())), @@ -639,36 +640,6 @@ pub(super) fn key_type_to_member_key(key_ty: &LuaType) -> Option { } } -fn collect_mapped_key_atoms(key_ty: &LuaType, acc: &mut Vec) { - match key_ty { - LuaType::Union(union) => { - for member in union.into_vec() { - collect_mapped_key_atoms(&member, acc); - } - } - LuaType::MultiLineUnion(multi) => { - for (member, _) in multi.get_unions() { - collect_mapped_key_atoms(member, acc); - } - } - LuaType::Variadic(variadic) => match variadic.deref() { - VariadicType::Base(base) => collect_mapped_key_atoms(base, acc), - VariadicType::Multi(types) => { - for member in types { - collect_mapped_key_atoms(member, acc); - } - } - }, - LuaType::Tuple(tuple) => { - for member in tuple.get_types() { - collect_mapped_key_atoms(member, acc); - } - } - LuaType::Unknown | LuaType::Never => {} - _ => acc.push(key_ty.clone()), - } -} - pub(super) fn get_default_constructor(db: &DbIndex, decl_id: &LuaTypeDeclId) -> Option { let ids = db .get_operator_index() diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs index a6c2751fe..75d498ca4 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs @@ -49,6 +49,7 @@ fn generic_tpl_pattern_match_inner( .ok_or(InferFailReason::None)?; if target_decl.is_alias() { let substitutor = TypeSubstitutor::from_alias( + context.db, target_generic.get_params().clone(), target_base.clone(), ); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 657e04897..9bf830086 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -656,12 +656,12 @@ fn param_type_list_pattern_match_type_list( let tpl_id = generic_tpl.get_tpl_id(); if let Some(inferred_type_value) = context.substitutor.get(tpl_id) { match inferred_type_value { - SubstitutorValue::Type(_) => { + SubstitutorValue::Type { .. } => { continue; } - SubstitutorValue::MultiTypes { types, .. } => { - if types.len() > 1 { - target_offset += types.len() - 1; + SubstitutorValue::MultiTypes { values, .. } => { + if values.len() > 1 { + target_offset += values.len() - 1; } continue; } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index 5f6f03be7..cfc5dfd22 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -4,6 +4,9 @@ use super::instantiate_type::{TplCandidateSource, finalize_inferred_tpl_candidat use crate::{DbIndex, GenericTpl, GenericTplId, LuaType, LuaTypeDeclId}; use std::sync::Arc; +const MAX_INSTANTIATION_DEPTH: usize = 128; +const MAX_ALIAS_STACK: usize = 32; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(super) enum UninferredTplPolicy { /// 未推断模板按 `default -> constraint -> unknown` 推断成实际类型. @@ -34,7 +37,13 @@ pub(in crate::semantic::generic) enum TplBinding { pub struct GenericInstantiateContext<'a> { pub db: &'a DbIndex, pub substitutor: &'a TypeSubstitutor, + alias_stack: Arc<[LuaTypeDeclId]>, +} + +#[derive(Debug, Clone, Copy)] +pub(super) struct GenericInstantiateFrame { policy: UninferredTplPolicy, + depth: usize, } impl<'a> GenericInstantiateContext<'a> { @@ -42,32 +51,68 @@ impl<'a> GenericInstantiateContext<'a> { Self { db, substitutor, - policy: UninferredTplPolicy::Fallback, + alias_stack: Arc::from([]), } } - pub(super) fn with_policy(&self, policy: UninferredTplPolicy) -> GenericInstantiateContext<'a> { - GenericInstantiateContext { - db: self.db, - substitutor: self.substitutor, - policy, + pub(super) fn root_frame(&self) -> GenericInstantiateFrame { + GenericInstantiateFrame { + policy: UninferredTplPolicy::Fallback, + depth: 0, } } - pub fn with_substitutor<'b>( + pub(super) fn with_substitutor<'b>( &'b self, substitutor: &'b TypeSubstitutor, ) -> GenericInstantiateContext<'b> { GenericInstantiateContext { db: self.db, substitutor, - policy: self.policy, + alias_stack: self.alias_stack.clone(), + } + } + + pub(super) fn enter_alias( + &self, + alias_type_id: &LuaTypeDeclId, + ) -> Option> { + if self.alias_stack.len() >= MAX_ALIAS_STACK + || self.alias_stack.iter().any(|id| id == alias_type_id) + { + return None; } + + let mut alias_stack = Vec::with_capacity(self.alias_stack.len() + 1); + alias_stack.extend(self.alias_stack.iter().cloned()); + alias_stack.push(alias_type_id.clone()); + Some(GenericInstantiateContext { + db: self.db, + substitutor: self.substitutor, + alias_stack: Arc::from(alias_stack), + }) + } +} + +impl GenericInstantiateFrame { + pub(super) fn with_policy(self, policy: UninferredTplPolicy) -> Self { + Self { policy, ..self } } pub fn should_preserve_tpl_ref(&self) -> bool { self.policy == UninferredTplPolicy::PreserveTplRef } + + pub(super) fn enter(self) -> Option { + if self.depth >= MAX_INSTANTIATION_DEPTH { + return None; + } + + Some(Self { + depth: self.depth + 1, + ..self + }) + } } #[derive(Debug, Clone)] @@ -97,11 +142,11 @@ impl TypeSubstitutor { for (i, ty) in type_array.into_iter().enumerate() { tpl_replace_map.insert( GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new( - ty, - TplCandidateSource::Finalized, - true, - )), + SubstitutorValue::Type { + value: SubstitutorTypeValue::new(ty, true), + source: TplCandidateSource::Finalized, + top_level: true, + }, ); } Self { @@ -111,18 +156,29 @@ impl TypeSubstitutor { } } - pub fn from_alias(type_array: Vec, alias_type_id: LuaTypeDeclId) -> Self { + pub fn from_alias( + db: &DbIndex, + type_array: Vec, + alias_type_id: LuaTypeDeclId, + ) -> Self { + let params = db.get_type_index().get_generic_params(&alias_type_id); + let mut tpl_replace_map = HashMap::new(); for (i, ty) in type_array.into_iter().enumerate() { + let tpl_id = params + .and_then(|params| params.get(i)) + .and_then(|param| param.tpl_id) + .unwrap_or(GenericTplId::Type(i as u32)); tpl_replace_map.insert( - GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new( - ty, - TplCandidateSource::Finalized, - true, - )), + tpl_id, + SubstitutorValue::Type { + value: SubstitutorTypeValue::new(ty, true), + source: TplCandidateSource::Finalized, + top_level: true, + }, ); } + Self { tpl_replace_map, alias_type_id: Some(alias_type_id), @@ -163,11 +219,11 @@ impl TypeSubstitutor { self.tpl_replace_map.insert( tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new( - replace_type, - TplCandidateSource::ConstPreserving, - true, - )), + SubstitutorValue::Type { + value: SubstitutorTypeValue::new(replace_type, false), + source: TplCandidateSource::ConstPreserving, + top_level: true, + }, ); } TplBinding::ReplaceConstType(replace_type) => { @@ -177,11 +233,11 @@ impl TypeSubstitutor { self.tpl_replace_map.insert( tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new( - replace_type, - TplCandidateSource::ConstPreserving, - true, - )), + SubstitutorValue::Type { + value: SubstitutorTypeValue::new(replace_type, false), + source: TplCandidateSource::ConstPreserving, + top_level: true, + }, ); } binding => { @@ -191,18 +247,20 @@ impl TypeSubstitutor { } let value = match binding { - TplBinding::FinalizedType(replace_type) => { - SubstitutorValue::Type(SubstitutorTypeValue::new( - replace_type, - TplCandidateSource::Finalized, - true, - )) - } + TplBinding::FinalizedType(replace_type) => SubstitutorValue::Type { + value: SubstitutorTypeValue::new(replace_type, true), + source: TplCandidateSource::Finalized, + top_level: true, + }, TplBinding::InferredType { ty, source, top_level, - } => SubstitutorValue::Type(SubstitutorTypeValue::new(ty, source, top_level)), + } => SubstitutorValue::Type { + value: SubstitutorTypeValue::new(ty, false), + source, + top_level, + }, TplBinding::VariadicParams(params) => { let params = params .into_iter() @@ -215,8 +273,15 @@ impl TypeSubstitutor { source, top_level, } => SubstitutorValue::MultiTypes { - raw_types: types.clone(), - types, + values: types + .into_iter() + .map(|ty| { + SubstitutorTypeValue::new( + ty, + source == TplCandidateSource::Finalized, + ) + }) + .collect(), source, top_level, }, @@ -245,7 +310,7 @@ impl TypeSubstitutor { pub fn get_raw_type(&self, tpl_id: GenericTplId) -> Option<&LuaType> { match self.tpl_replace_map.get(&tpl_id) { - Some(SubstitutorValue::Type(ty)) => Some(ty.raw()), + Some(SubstitutorValue::Type { value, .. }) => Some(value.raw()), _ => None, } } @@ -259,44 +324,66 @@ impl TypeSubstitutor { for tpl in generic_tpls { let tpl_id = tpl.get_tpl_id(); let return_top_level = is_tpl_at_top_level(db, return_type, tpl_id); - let substitutor = self.clone(); - let Some(value) = self.tpl_replace_map.get_mut(&tpl_id) else { + let Some(value) = self.tpl_replace_map.get(&tpl_id) else { continue; }; - match value { - SubstitutorValue::Type(ty) => { - ty.finalize(db, tpl.as_ref(), return_top_level, &substitutor) - } - SubstitutorValue::MultiTypes { - raw_types, - types, + let finalized_value = match value { + SubstitutorValue::Type { + value, source, top_level, } => { - if *source == TplCandidateSource::Finalized { - continue; - } - let finalized = types - .iter() - .map(|ty| { - finalize_inferred_tpl_candidate( + if value.is_finalized() { + None + } else { + Some(SubstitutorValue::Type { + value: value.finalized( db, tpl.as_ref(), - ty, *source, *top_level, return_top_level, - &substitutor, - ) + self, + ), + source: TplCandidateSource::Finalized, + top_level: true, + }) + } + } + SubstitutorValue::MultiTypes { + values, + source, + top_level, + } => { + if *source == TplCandidateSource::Finalized { + None + } else { + let values = values + .iter() + .map(|value| { + value.finalized( + db, + tpl.as_ref(), + *source, + *top_level, + return_top_level, + self, + ) + }) + .collect(); + Some(SubstitutorValue::MultiTypes { + values, + source: TplCandidateSource::Finalized, + top_level: true, }) - .collect(); - *raw_types = types.clone(); - *types = finalized; - *source = TplCandidateSource::Finalized; - *top_level = true; + } } - _ => {} + _ => None, + }; + + if let Some(finalized_value) = finalized_value { + self.tpl_replace_map.insert(tpl_id, finalized_value); } } } @@ -324,20 +411,13 @@ impl TypeSubstitutor { pub struct SubstitutorTypeValue { raw: LuaType, finalized: Option, - source: TplCandidateSource, - top_level: bool, } impl SubstitutorTypeValue { - fn new(raw: LuaType, source: TplCandidateSource, top_level: bool) -> Self { + fn new(raw: LuaType, already_finalized: bool) -> Self { let raw = into_ref_type(raw); - let finalized = (source == TplCandidateSource::Finalized).then(|| raw.clone()); - Self { - raw, - finalized, - source, - top_level, - } + let finalized = already_finalized.then(|| raw.clone()); + Self { raw, finalized } } pub fn raw(&self) -> &LuaType { @@ -348,39 +428,46 @@ impl SubstitutorTypeValue { self.finalized.as_ref().unwrap_or(&self.raw) } - fn finalize( - &mut self, + fn is_finalized(&self) -> bool { + self.finalized.is_some() + } + + fn finalized( + &self, db: &DbIndex, tpl: &GenericTpl, + source: TplCandidateSource, + top_level: bool, return_top_level: bool, substitutor: &TypeSubstitutor, - ) { - if self.finalized.is_some() { - return; - } - - self.finalized = Some(finalize_inferred_tpl_candidate( + ) -> Self { + let finalized = finalize_inferred_tpl_candidate( db, tpl, &self.raw, - self.source, - self.top_level, + source, + top_level, return_top_level, substitutor, - )); - self.source = TplCandidateSource::Finalized; - self.top_level = true; + ); + Self { + raw: self.raw.clone(), + finalized: Some(finalized), + } } } #[derive(Debug, Clone)] pub(super) enum SubstitutorValue { None, - Type(SubstitutorTypeValue), + Type { + value: SubstitutorTypeValue, + source: TplCandidateSource, + top_level: bool, + }, Params(Vec<(String, Option)>), MultiTypes { - raw_types: Vec, - types: Vec, + values: Vec, source: TplCandidateSource, top_level: bool, }, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index 9231e09e6..a3cb628af 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -131,14 +131,13 @@ pub fn infer_call_expr_func( }; let result = if let Ok(func_ty) = result { - let func_ty = match func_ty.get_ret() { - LuaType::Call(_) => { - match infer_call_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { - Ok(func_ty) => Arc::new(func_ty), - Err(_) => func_ty, - } + let func_ty = if func_ty.get_ret().contain_tpl() || func_ty.get_ret().is_call() { + match infer_call_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { + Ok(func_ty) => Arc::new(func_ty), + Err(_) => func_ty, } - _ => func_ty, + } else { + func_ty }; let func_ret = func_ty.get_ret(); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs index 1f3ad09ce..8c11e5f24 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs @@ -694,14 +694,22 @@ fn infer_generic_member( ) -> InferResult { let base_type = generic_type.get_base_type(); - let generic_params = generic_type.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); - if let LuaType::Ref(base_type_decl_id) = &base_type { let type_index = db.get_type_index(); - if let Some(type_decl) = type_index.get_type_decl(base_type_decl_id) - && type_decl.is_alias() - && let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) + let Some(type_decl) = type_index.get_type_decl(base_type_decl_id) else { + return Err(InferFailReason::None); + }; + let generic_params = generic_type.get_params(); + let substitutor = if type_decl.is_alias() { + TypeSubstitutor::from_alias(db, generic_params.clone(), base_type_decl_id.clone()) + } else { + TypeSubstitutor::from_type_array(generic_params.clone()) + }; + + if type_decl.is_alias() + && let Some(origin_type) = type_decl + .get_alias_ref() + .map(|origin| instantiate_type_generic(db, origin, &substitutor)) { return infer_member_by_lookup(db, cache, &origin_type, lookup, &infer_guard.fork()); } @@ -717,11 +725,12 @@ fn infer_generic_member( if let Some(result) = result { return Ok(result); } - } - let member_type = infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard)?; + let member_type = infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard)?; + return Ok(instantiate_type_generic(db, &member_type, &substitutor)); + } - Ok(instantiate_type_generic(db, &member_type, &substitutor)) + infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard) } fn infer_instance_member( @@ -986,17 +995,24 @@ fn infer_member_by_index_generic( return Err(InferFailReason::None); }; let generic_params = generic.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); let type_index = db.get_type_index(); let type_decl = type_index .get_type_decl(&type_decl_id) .ok_or(InferFailReason::None)?; + let substitutor = if type_decl.is_alias() { + TypeSubstitutor::from_alias(db, generic_params.clone(), type_decl_id.clone()) + } else { + TypeSubstitutor::from_type_array(generic_params.clone()) + }; if type_decl.is_alias() { - if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { + if let Some(origin_type) = type_decl + .get_alias_ref() + .map(|origin| instantiate_type_generic(db, origin, &substitutor)) + { return infer_member_by_operator_key_type( db, cache, - &instantiate_type_generic(db, &origin_type, &substitutor), + &origin_type, key_type, &infer_guard.fork(), ); diff --git a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs index 606ba27b5..58284dc3d 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs @@ -495,10 +495,17 @@ fn find_generic_members( .iter() .map(|param| ctx.instantiate_type(db, param)) .collect(); - let substitutor = TypeSubstitutor::from_type_array(instantiated_params); let type_decl = db.get_type_index().get_type_decl(&base_ref_id)?; + let substitutor = if type_decl.is_alias() { + TypeSubstitutor::from_alias(db, instantiated_params, base_ref_id.clone()) + } else { + TypeSubstitutor::from_type_array(instantiated_params) + }; let ctx_with_substitutor = ctx.with_substitutor(substitutor.clone()); - if let Some(origin) = type_decl.get_alias_origin(db, Some(&substitutor)) { + if let Some(origin) = type_decl + .get_alias_ref() + .map(|origin| instantiate_type_generic(db, origin, &substitutor)) + { return find_members_guard(db, &origin, &ctx_with_substitutor, filter); } diff --git a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs index cadd3988e..7edaffb2a 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs @@ -214,14 +214,18 @@ fn infer_generic_raw_member_type( ) -> RawGetMemberTypeResult { let base_ref_id = generic_type.get_base_type_id_ref(); let generic_params = generic_type.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); let type_decl = db .get_type_index() .get_type_decl(&base_ref_id) .ok_or(InferFailReason::None)?; + let substitutor = if type_decl.is_alias() { + TypeSubstitutor::from_alias(db, generic_params.clone(), base_ref_id.clone()) + } else { + TypeSubstitutor::from_type_array(generic_params.clone()) + }; if let Some(origin) = type_decl.get_alias_origin(db, Some(&substitutor)) { - return infer_raw_member_type(db, &origin, member_key); + return infer_raw_member_type_guard(db, &origin, member_key, infer_guard); } let base_ref_type = LuaType::Ref(base_ref_id.clone()); diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs index e9827448b..763258b1f 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs @@ -37,8 +37,11 @@ pub fn check_complex_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); + let substitutor = TypeSubstitutor::from_alias( + context.db, + generic.get_params().clone(), + base_id.clone(), + ); if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { return check_general_type_compact( context, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs index afee3eddc..e7af2fa1c 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs @@ -23,8 +23,11 @@ pub fn check_doc_func_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); + let substitutor = TypeSubstitutor::from_alias( + context.db, + generic.get_params().clone(), + base_id.clone(), + ); if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { return check_general_type_compact( context, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs index 0929c7ed5..85bec1bb7 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs @@ -24,9 +24,13 @@ pub fn check_generic_type_compact( .get_type_decl(&source_generic.get_base_type_id()) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(source_generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { + let substitutor = TypeSubstitutor::from_alias( + context.db, + source_generic.get_params().clone(), + base_id.clone(), + ); + if let Some(alias_ref) = decl.get_alias_ref() { + let alias_origin = instantiate_type_generic(context.db, alias_ref, &substitutor); return check_general_type_compact( context, &alias_origin, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index 6f994cb0c..70865eec2 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -310,8 +310,11 @@ pub fn check_simple_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); + let substitutor = TypeSubstitutor::from_alias( + context.db, + generic.get_params().clone(), + base_id.clone(), + ); if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { From a62da075d93ba67cedef9ae4de61caf5f068212c Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Sat, 16 May 2026 08:47:41 +0800 Subject: [PATCH 03/10] Improves distributed conditional generic inference --- .../src/compilation/test/generic_test.rs | 20 +++++++++++++++++++ .../instantiate_conditional_generic.rs | 4 +++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 5092c6e3f..694e165e0 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -1779,6 +1779,26 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "number"); } + #[test] + fn test_distributed_function_generic_conditional_return_filters_union_members() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T extends string and T or never + function extractString(value) end + + ---@type string|integer + local value + + A = extractString(value) + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + } + #[test] fn test_union_never() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index d12d8b3fb..bfb24ad50 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -188,7 +188,9 @@ fn instantiate_distributed_conditional( conditional: &LuaConditionalType, ) -> Option { let tpl_id = match conditional.get_checked_type() { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) if tpl.get_tpl_id().is_type() => { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if tpl.get_tpl_id().is_type() || tpl.get_tpl_id().is_func() => + { tpl.get_tpl_id() } _ => return None, From 0e979acd8c3d19b385285ac18bef64dc92d6f95c Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Sat, 16 May 2026 20:06:37 +0800 Subject: [PATCH 04/10] fix generic --- .../analyzer/doc/file_generic_index.rs | 14 -- .../src/compilation/analyzer/lua/stats.rs | 141 +++++++++--- .../compilation/test/generic_infer_test.rs | 51 +++++ .../src/compilation/test/generic_test.rs | 75 +++++- .../instantiate_type/inference_widening.rs | 213 +++++++++++++----- .../semantic/generic/instantiate_type/mod.rs | 2 +- .../src/semantic/generic/type_substitutor.rs | 8 + 7 files changed, 402 insertions(+), 102 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs index fa35dd075..5f26acc62 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs @@ -86,20 +86,6 @@ impl GenericIndex for FileGenericIndex { None } - fn append_generic_params( - &mut self, - scope_id: GenericScopeId, - params: Vec, - ) -> Vec { - let mut appended = Vec::new(); - for param in params { - if let Some(tpl_id) = self.append_generic_param(scope_id, param.clone()) { - appended.push(param.with_tpl_id(Some(tpl_id))); - } - } - appended - } - /// Find generic parameter by position and name. /// return (GenericTplId, constraint, default) fn find_generic( diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index 3bb9875bf..4cd8389a9 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -53,7 +53,7 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) Ok(expr_type) => { let expr_type = expr_type.get_result_slot_type(0).unwrap_or(expr_type); let decl_id = LuaDeclId::new(analyzer.file_id, position); - // 当`call`参数包含表时, 表可能未被分析, 需要延迟 + // 当表达式中存在带表参数的调用时, 表可能尚未完成预分析, 需要延迟 if let LuaType::Instance(instance) = &expr_type && instance.get_base().is_unknown() && call_expr_has_effect_table_arg(&expr).is_some() @@ -164,17 +164,7 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) } fn call_expr_has_effect_table_arg(expr: &LuaExpr) -> Option<()> { - if let LuaExpr::CallExpr(call_expr) = expr { - let args_list = call_expr.get_args_list()?; - for arg in args_list.get_args() { - if let LuaExpr::TableExpr(table_expr) = arg - && !table_expr.is_empty() - { - return Some(()); - } - } - } - None + expr_has_effect_table_call_arg(expr.clone()) } fn get_var_owner(analyzer: &mut LuaAnalyzer, var: LuaVarExpr) -> LuaTypeOwner { @@ -642,27 +632,122 @@ fn get_delayed_definition_decl_id( } fn pre_analyze_call_arg_table_fields(analyzer: &mut LuaAnalyzer, expr: &LuaExpr) { - let LuaExpr::CallExpr(call_expr) = expr else { - return; - }; - let Some(args_list) = call_expr.get_args_list() else { - return; - }; + pre_analyze_nested_table_fields(analyzer, expr.clone()); +} + +fn pre_analyze_nested_table_fields(analyzer: &mut LuaAnalyzer, expr: LuaExpr) { + match expr { + LuaExpr::CallExpr(call_expr) => { + if let Some(prefix_expr) = call_expr.get_prefix_expr() { + pre_analyze_nested_table_fields(analyzer, prefix_expr); + } + + if let Some(args_list) = call_expr.get_args_list() { + for arg in args_list.get_args() { + pre_analyze_nested_table_fields(analyzer, arg); + } + } + } + LuaExpr::TableExpr(table_expr) => { + for field in table_expr.get_fields() { + if let Some(LuaIndexKey::Expr(key_expr)) = field.get_field_key() { + pre_analyze_nested_table_fields(analyzer, key_expr); + } - for arg in args_list.get_args() { - pre_analyze_table_expr_fields(analyzer, arg); + if let Some(value_expr) = field.get_value_expr() { + pre_analyze_nested_table_fields(analyzer, value_expr); + } + + analyze_table_field(analyzer, field.clone()); + } + } + LuaExpr::BinaryExpr(binary_expr) => { + if let Some((left, right)) = binary_expr.get_exprs() { + pre_analyze_nested_table_fields(analyzer, left); + pre_analyze_nested_table_fields(analyzer, right); + } + } + LuaExpr::UnaryExpr(unary_expr) => { + if let Some(inner_expr) = unary_expr.get_expr() { + pre_analyze_nested_table_fields(analyzer, inner_expr); + } + } + LuaExpr::ParenExpr(paren_expr) => { + if let Some(inner_expr) = paren_expr.get_expr() { + pre_analyze_nested_table_fields(analyzer, inner_expr); + } + } + LuaExpr::IndexExpr(index_expr) => { + if let Some(prefix_expr) = index_expr.get_prefix_expr() { + pre_analyze_nested_table_fields(analyzer, prefix_expr); + } + + if let Some(LuaIndexKey::Expr(key_expr)) = index_expr.get_index_key() { + pre_analyze_nested_table_fields(analyzer, key_expr); + } + } + LuaExpr::LiteralExpr(_) | LuaExpr::ClosureExpr(_) | LuaExpr::NameExpr(_) => {} } } -fn pre_analyze_table_expr_fields(analyzer: &mut LuaAnalyzer, expr: LuaExpr) { - let LuaExpr::TableExpr(table_expr) = expr else { - return; - }; +fn expr_has_effect_table_call_arg(expr: LuaExpr) -> Option<()> { + match expr { + LuaExpr::CallExpr(call_expr) => { + if let Some(prefix_expr) = call_expr.get_prefix_expr() + && expr_has_effect_table_call_arg(prefix_expr).is_some() + { + return Some(()); + } + + let args_list = call_expr.get_args_list()?; + for arg in args_list.get_args() { + if let LuaExpr::TableExpr(table_expr) = &arg + && !table_expr.is_empty() + { + return Some(()); + } + + if expr_has_effect_table_call_arg(arg).is_some() { + return Some(()); + } + } + None + } + LuaExpr::TableExpr(table_expr) => { + for field in table_expr.get_fields() { + if let Some(LuaIndexKey::Expr(key_expr)) = field.get_field_key() + && expr_has_effect_table_call_arg(key_expr).is_some() + { + return Some(()); + } + + if let Some(value_expr) = field.get_value_expr() + && expr_has_effect_table_call_arg(value_expr).is_some() + { + return Some(()); + } + } + None + } + LuaExpr::BinaryExpr(binary_expr) => { + let (left, right) = binary_expr.get_exprs()?; + expr_has_effect_table_call_arg(left).or_else(|| expr_has_effect_table_call_arg(right)) + } + LuaExpr::UnaryExpr(unary_expr) => expr_has_effect_table_call_arg(unary_expr.get_expr()?), + LuaExpr::ParenExpr(paren_expr) => expr_has_effect_table_call_arg(paren_expr.get_expr()?), + LuaExpr::IndexExpr(index_expr) => { + if let Some(prefix_expr) = index_expr.get_prefix_expr() + && expr_has_effect_table_call_arg(prefix_expr).is_some() + { + return Some(()); + } + + if let Some(LuaIndexKey::Expr(key_expr)) = index_expr.get_index_key() { + return expr_has_effect_table_call_arg(key_expr); + } - for field in table_expr.get_fields() { - analyze_table_field(analyzer, field.clone()); - if let Some(value_expr) = field.get_value_expr() { - pre_analyze_table_expr_fields(analyzer, value_expr); + None } + LuaExpr::LiteralExpr(_) | LuaExpr::ClosureExpr(_) | LuaExpr::NameExpr(_) => None, } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs index 5eef7210c..226513f3e 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs @@ -47,6 +47,57 @@ mod test { assert_eq!(ws.humanize_type(a_ty), "string"); } + #[test] + fn test_object_literal_infer_nested_call_argument() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias ExtractX T extends { x: infer X } and X or never + + ---@generic T + ---@param value T + ---@return ExtractX + function extractX(value) end + + ---@generic T + ---@param value T + ---@return T + function identity(value) end + + A = identity(extractX({ x = 1 })) + "#, + ); + + let a_ty = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(a_ty), "integer"); + } + + #[test] + fn test_object_literal_infer_nested_call_inside_table_field() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias ExtractX T extends { x: infer X } and X or never + ---@alias ExtractInner T extends { inner: infer I } and I or never + + ---@generic T + ---@param value T + ---@return ExtractX + function extractX(value) end + + ---@generic T + ---@param value T + ---@return ExtractInner + function extractInner(value) end + + B = extractInner({ inner = extractX({ x = 1 }) }) + "#, + ); + + let b_ty = ws.expr_ty("B"); + assert_eq!(ws.humanize_type(b_ty), "integer"); + } + #[test] fn test_object_literal_infer_from_class() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 694e165e0..3c99430a6 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -1483,6 +1483,46 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "\"mode\""); } + #[test] + fn test_transparent_alias_root_union_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Id T + + ---@generic T + ---@param value T + ---@return Id|nil + function maybe(value) + end + + result = maybe("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), r#""mode"?"#); + } + + #[test] + fn test_plain_tpl_root_union_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T|nil + function maybe(value) + end + + result = maybe("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), r#""mode"?"#); + } + #[test] fn test_plain_tpl_top_level_return_preserves_primitive_literal_union() { let mut ws = VirtualWorkspace::new(); @@ -1526,15 +1566,43 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "\"mode\""); } + #[test] + fn test_finalized_table_const_self_reference_widens_without_recursing() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function id(value) + end + + local t = { kind = "mode" } + t.self = t + + result = id(t) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("{ kind: string, self: table }")); + } + #[test] fn test_contextual_widening_keeps_bare_literal_but_widens_nested_literals() { - use crate::{LuaMemberKey, LuaObjectType, WideningContext, widen_type_with_context}; + use crate::{ + LuaMemberKey, LuaObjectType, WideningContext, WideningGuard, widen_type_with_context, + }; use smol_str::SmolStr; let mut ws = VirtualWorkspace::new(); let bare = LuaType::StringConst(SmolStr::new("mode").into()); assert_eq!( - widen_type_with_context(bare.clone(), WideningContext::Root), + widen_type_with_context( + bare.clone(), + WideningContext::Root, + &mut WideningGuard::default() + ), bare ); @@ -1550,7 +1618,8 @@ mod test { ) .into(), ); - let widened = widen_type_with_context(object, WideningContext::Root); + let widened = + widen_type_with_context(object, WideningContext::Root, &mut WideningGuard::default()); assert_eq!(widened, ws.ty("{ kind: string }")); } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs index ba4cae0a4..0764615b9 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs @@ -1,9 +1,10 @@ use std::{ops::Deref, sync::Arc}; -use hashbrown::HashMap; +use hashbrown::{HashMap, HashSet}; +use rowan::TextRange; use crate::{ - DbIndex, GenericParam, GenericTpl, LuaArrayType, LuaConditionalType, LuaFunctionType, + DbIndex, GenericParam, GenericTpl, InFiled, LuaArrayType, LuaConditionalType, LuaFunctionType, LuaGenericType, LuaMappedType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaTupleType, LuaType, LuaUnionType, TypeOps, TypeSubstitutor, VariadicType, instantiate_type_generic, }; @@ -38,15 +39,14 @@ pub(in crate::semantic::generic) fn finalize_inferred_tpl_candidate( let candidate = if primitive_constraint || !top_level || return_top_level { raw_candidate.clone() } else { - match raw_candidate { - LuaType::FloatConst(_) => LuaType::Number, - LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, - LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, - LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, - _ => raw_candidate.clone(), - } + widen_literal_type(raw_candidate.clone()) }; - widen_finalized_candidate_type(db, candidate, WideningContext::Root) + finalize_tpl_candidate_type( + db, + candidate, + WideningContext::Root, + &mut WideningGuard::default(), + ) } fn is_primitive_or_literal_type(ty: &LuaType) -> bool { @@ -87,10 +87,52 @@ pub enum WideningContext { VariadicElement, } -fn widen_finalized_candidate_type(db: &DbIndex, ty: LuaType, context: WideningContext) -> LuaType { - match ty { +const MAX_WIDENING_DEPTH: u16 = 100; + +#[derive(Default)] +pub struct WideningGuard { + depth: u16, + active_table_ids: HashSet>, +} + +impl WideningGuard { + fn enter_level(&mut self) -> bool { + if self.depth >= MAX_WIDENING_DEPTH { + return false; + } + self.depth += 1; + true + } + + fn leave_level(&mut self) { + self.depth = self.depth.saturating_sub(1); + } + + fn enter_table(&mut self, table_id: &InFiled) -> bool { + self.active_table_ids.insert(table_id.clone()) + } + + fn leave_table(&mut self, table_id: &InFiled) { + self.active_table_ids.remove(table_id); + } +} + +fn finalize_tpl_candidate_type( + db: &DbIndex, + ty: LuaType, + context: WideningContext, + guard: &mut WideningGuard, +) -> LuaType { + if !guard.enter_level() { + return match ty { + LuaType::TableConst(_) => LuaType::Table, + ty => widen_literals_with_context(ty, context), + }; + } + + let widened = match ty { LuaType::TableConst(table_id) => { - table_const_to_object(db, table_id).unwrap_or(LuaType::Table) + table_const_to_object(db, table_id, guard).unwrap_or(LuaType::Table) } LuaType::Object(object) => { let fields = object @@ -99,10 +141,11 @@ fn widen_finalized_candidate_type(db: &DbIndex, ty: LuaType, context: WideningCo .map(|(key, ty)| { ( key.clone(), - widen_finalized_candidate_type( + finalize_tpl_candidate_type( db, ty.clone(), WideningContext::ObjectProperty, + guard, ), ) }) @@ -112,11 +155,16 @@ fn widen_finalized_candidate_type(db: &DbIndex, ty: LuaType, context: WideningCo .iter() .map(|(key, value)| { ( - widen_type_with_context(key.clone(), WideningContext::ObjectProperty), - widen_finalized_candidate_type( + widen_type_with_context( + key.clone(), + WideningContext::ObjectProperty, + guard, + ), + finalize_tpl_candidate_type( db, value.clone(), WideningContext::ObjectProperty, + guard, ), ) }) @@ -129,7 +177,7 @@ fn widen_finalized_candidate_type(db: &DbIndex, ty: LuaType, context: WideningCo _ => WideningContext::ArrayElement, }; let base = - widen_finalized_candidate_type(db, array.get_base().clone(), element_context); + finalize_tpl_candidate_type(db, array.get_base().clone(), element_context, guard); LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) } LuaType::Tuple(tuple) => { @@ -137,7 +185,7 @@ fn widen_finalized_candidate_type(db: &DbIndex, ty: LuaType, context: WideningCo .get_types() .iter() .cloned() - .map(|ty| widen_finalized_candidate_type(db, ty, WideningContext::TupleElement)) + .map(|ty| finalize_tpl_candidate_type(db, ty, WideningContext::TupleElement, guard)) .collect(); LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) } @@ -152,34 +200,37 @@ fn widen_finalized_candidate_type(db: &DbIndex, ty: LuaType, context: WideningCo union .into_vec() .into_iter() - .map(|ty| widen_finalized_candidate_type(db, ty, member_context)) + .map(|ty| finalize_tpl_candidate_type(db, ty, member_context, guard)) .collect(), ) .into(), ) } - ty => widen_type_with_context(ty, context), - } + ty => widen_type_with_context(ty, context, guard), + }; + + guard.leave_level(); + widened } -pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType { - let widen_literals = !matches!(context, WideningContext::Root); +pub fn widen_type_with_context( + ty: LuaType, + context: WideningContext, + guard: &mut WideningGuard, +) -> LuaType { + if !guard.enter_level() { + return widen_literals_with_context(ty, context); + } + + let ty = widen_literals_with_context(ty, context); - match ty { - LuaType::FloatConst(_) if widen_literals => LuaType::Number, - LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) if widen_literals => { - LuaType::Integer - } - LuaType::DocStringConst(_) | LuaType::StringConst(_) if widen_literals => LuaType::String, - LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) if widen_literals => { - LuaType::Boolean - } + let widened = match ty { LuaType::Array(array) => { let element_context = match context { WideningContext::TupleElement => WideningContext::TupleElement, _ => WideningContext::ArrayElement, }; - let base = widen_type_with_context(array.get_base().clone(), element_context); + let base = widen_type_with_context(array.get_base().clone(), element_context, guard); LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) } LuaType::Tuple(tuple) => { @@ -187,7 +238,7 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType .get_types() .iter() .cloned() - .map(|ty| widen_type_with_context(ty, WideningContext::TupleElement)) + .map(|ty| widen_type_with_context(ty, WideningContext::TupleElement, guard)) .collect(); LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) } @@ -198,7 +249,7 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType .map(|(key, ty)| { ( key.clone(), - widen_type_with_context(ty.clone(), WideningContext::ObjectProperty), + widen_type_with_context(ty.clone(), WideningContext::ObjectProperty, guard), ) }) .collect(); @@ -207,8 +258,16 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType .iter() .map(|(key, value)| { ( - widen_type_with_context(key.clone(), WideningContext::ObjectProperty), - widen_type_with_context(value.clone(), WideningContext::ObjectProperty), + widen_type_with_context( + key.clone(), + WideningContext::ObjectProperty, + guard, + ), + widen_type_with_context( + value.clone(), + WideningContext::ObjectProperty, + guard, + ), ) }) .collect(); @@ -225,7 +284,7 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType union .into_vec() .into_iter() - .map(|ty| widen_type_with_context(ty, member_context)) + .map(|ty| widen_type_with_context(ty, member_context, guard)) .collect(), ) .into(), @@ -238,7 +297,11 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType .iter() .map(|(ty, description)| { ( - widen_type_with_context(ty.clone(), WideningContext::UnionMember), + widen_type_with_context( + ty.clone(), + WideningContext::UnionMember, + guard, + ), description.clone(), ) }) @@ -252,7 +315,7 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType .get_types() .iter() .cloned() - .map(|ty| widen_type_with_context(ty, WideningContext::UnionMember)) + .map(|ty| widen_type_with_context(ty, WideningContext::UnionMember, guard)) .collect(), ) .into(), @@ -262,12 +325,15 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType VariadicType::Base(base) => VariadicType::Base(widen_type_with_context( base.clone(), WideningContext::VariadicElement, + guard, )), VariadicType::Multi(types) => VariadicType::Multi( types .iter() .cloned() - .map(|ty| widen_type_with_context(ty, WideningContext::VariadicElement)) + .map(|ty| { + widen_type_with_context(ty, WideningContext::VariadicElement, guard) + }) .collect(), ), } @@ -280,7 +346,7 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType .get_params() .iter() .cloned() - .map(|ty| widen_type_with_context(ty, WideningContext::Root)) + .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)) .collect(), ) .into(), @@ -289,7 +355,7 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType params .iter() .cloned() - .map(|ty| widen_type_with_context(ty, WideningContext::Root)) + .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)) .collect::>() .into(), ), @@ -303,32 +369,41 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType .map(|(name, ty)| { ( name.clone(), - ty.clone() - .map(|ty| widen_type_with_context(ty, WideningContext::Root)), + ty.clone().map(|ty| { + widen_type_with_context(ty, WideningContext::Root, guard) + }), ) }) .collect(), - widen_type_with_context(func.get_ret().clone(), WideningContext::Root), + widen_type_with_context(func.get_ret().clone(), WideningContext::Root, guard), ) .into(), ), - LuaType::TypeGuard(guard) => LuaType::TypeGuard( - widen_type_with_context(guard.deref().clone(), WideningContext::Root).into(), + LuaType::TypeGuard(type_guard) => LuaType::TypeGuard( + widen_type_with_context(type_guard.deref().clone(), WideningContext::Root, guard) + .into(), ), LuaType::Conditional(conditional) => LuaType::Conditional( LuaConditionalType::new( widen_type_with_context( conditional.get_checked_type().clone(), WideningContext::Root, + guard, ), widen_type_with_context( conditional.get_extends_type().clone(), WideningContext::Root, + guard, + ), + widen_type_with_context( + conditional.get_true_type().clone(), + WideningContext::Root, + guard, ), - widen_type_with_context(conditional.get_true_type().clone(), WideningContext::Root), widen_type_with_context( conditional.get_false_type().clone(), WideningContext::Root, + guard, ), conditional.get_infer_params().to_vec(), conditional.has_new, @@ -345,30 +420,54 @@ pub fn widen_type_with_context(ty: LuaType, context: WideningContext) -> LuaType .1 .type_constraint .clone() - .map(|ty| widen_type_with_context(ty, WideningContext::Root)), + .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)), mapped .param .1 .default_type .clone() - .map(|ty| widen_type_with_context(ty, WideningContext::Root)), + .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)), mapped.param.1.attributes.clone(), ), ), - widen_type_with_context(mapped.value.clone(), WideningContext::Root), + widen_type_with_context(mapped.value.clone(), WideningContext::Root, guard), mapped.is_readonly, mapped.is_optional, ))), ty => ty, + }; + + guard.leave_level(); + widened +} + +fn widen_literals_with_context(ty: LuaType, context: WideningContext) -> LuaType { + match context { + WideningContext::Root => ty, + _ => widen_literal_type(ty), + } +} + +fn widen_literal_type(ty: LuaType) -> LuaType { + match ty { + LuaType::FloatConst(_) => LuaType::Number, + LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, + LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, + LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, + ty => ty, } } fn table_const_to_object( db: &DbIndex, - table_id: crate::InFiled, + table_id: InFiled, + guard: &mut WideningGuard, ) -> Option { - let owner = LuaMemberOwner::Element(table_id); + let owner = LuaMemberOwner::Element(table_id.clone()); let members = db.get_member_index().get_members(&owner)?; + if !guard.enter_table(&table_id) { + return Some(LuaType::Table); + } let mut fields = HashMap::new(); let mut index_access = Vec::new(); @@ -378,7 +477,7 @@ fn table_const_to_object( .get_type_cache(&member.get_id().into()) .map(|cache| cache.as_type().clone()) .unwrap_or(LuaType::Unknown); - let value = widen_finalized_candidate_type(db, value, WideningContext::ObjectProperty); + let value = finalize_tpl_candidate_type(db, value, WideningContext::ObjectProperty, guard); match member.get_key() { LuaMemberKey::Name(_) | LuaMemberKey::Integer(_) => { @@ -391,7 +490,7 @@ fn table_const_to_object( } LuaMemberKey::ExprType(key) => { index_access.push(( - widen_type_with_context(key.clone(), WideningContext::ObjectProperty), + widen_type_with_context(key.clone(), WideningContext::ObjectProperty, guard), value, )); } @@ -399,6 +498,8 @@ fn table_const_to_object( } } + guard.leave_table(&table_id); + Some(LuaType::Object( LuaObjectType::new_with_fields(fields, index_access).into(), )) diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index 94612aaee..c15dd4c44 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -28,7 +28,7 @@ pub use infer_call_func_generic::{build_self_type, infer_call_func_generic, infe pub(in crate::semantic::generic) use inference_widening::{ TplCandidateSource, finalize_inferred_tpl_candidate, }; -pub use inference_widening::{WideningContext, widen_type_with_context}; +pub use inference_widening::{WideningContext, WideningGuard, widen_type_with_context}; use instantiate_mapped_type::instantiate_mapped_type as instantiate_mapped_type_inner; pub use instantiate_special_generic::get_keyof_members; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index cfc5dfd22..6dbf090da 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -492,6 +492,14 @@ fn is_tpl_at_top_level_with_guard( ) -> bool { match ty { LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + LuaType::Union(union) => union.into_vec().iter().any(|member| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), + LuaType::MultiLineUnion(multi) => multi.get_unions().iter().any(|(member, _)| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), LuaType::Generic(generic) => { let type_decl_id = generic.get_base_type_id_ref(); let Some(alias_param) = From ac3acba03f36579758c06256ed1c78f7aace504b Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Mon, 18 May 2026 22:44:26 +0800 Subject: [PATCH 05/10] update std.Unpack and std.ConstTpl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 现在 std.Unpack 的目标泛型必须使用 std.ConstTpl 包裹以保持常量传播 --- crates/emmylua_code_analysis/resources/std/builtin.lua | 2 +- crates/emmylua_code_analysis/resources/std/global.lua | 2 +- crates/emmylua_code_analysis/resources/std/table.lua | 2 +- .../src/compilation/test/unpack_test.rs | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/emmylua_code_analysis/resources/std/builtin.lua b/crates/emmylua_code_analysis/resources/std/builtin.lua index af7194301..60f6d9c49 100644 --- a/crates/emmylua_code_analysis/resources/std/builtin.lua +++ b/crates/emmylua_code_analysis/resources/std/builtin.lua @@ -129,7 +129,7 @@ --- @alias std.RawGet unknown --- ---- built-in type for generic template, for match integer const and true/false +--- built-in type for generic template, used for const generic --- @alias std.ConstTpl unknown --- compact luals diff --git a/crates/emmylua_code_analysis/resources/std/global.lua b/crates/emmylua_code_analysis/resources/std/global.lua index 220ca2de9..eadf553c0 100644 --- a/crates/emmylua_code_analysis/resources/std/global.lua +++ b/crates/emmylua_code_analysis/resources/std/global.lua @@ -463,7 +463,7 @@ function xpcall(f, msgh, ...) end --- @generic T, Start: integer, End: integer --- @param i? std.ConstTpl --- @param j? std.ConstTpl ---- @param list T +--- @param list std.ConstTpl --- @return std.Unpack function unpack(list, i, j) end diff --git a/crates/emmylua_code_analysis/resources/std/table.lua b/crates/emmylua_code_analysis/resources/std/table.lua index ee60a7343..171d858e6 100644 --- a/crates/emmylua_code_analysis/resources/std/table.lua +++ b/crates/emmylua_code_analysis/resources/std/table.lua @@ -109,7 +109,7 @@ function table.sort(list, comp) end --- @generic T, Start: integer, End: integer --- @param i? std.ConstTpl --- @param j? std.ConstTpl ---- @param list T +--- @param list std.ConstTpl --- @return std.Unpack function table.unpack(list, i, j) end diff --git a/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs b/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs index 606957d47..fc4fb062a 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs @@ -35,7 +35,7 @@ mod test { let mut ws = VirtualWorkspace::new_with_init_std_lib(); ws.def( r#" - ---@overload fun(t: T): std.Unpack + ---@overload fun(t: std.ConstTpl): std.Unpack ---@overload fun(t: number): number local function f(t) end @@ -55,7 +55,7 @@ mod test { ws.def( r#" ---@class Obj - ---@field unpack (fun(self: Obj, t: T): std.Unpack) | (fun(self: Obj, t: number): number) + ---@field unpack (fun(self: Obj, t: std.ConstTpl): std.Unpack) | (fun(self: Obj, t: number): number) local Obj = {} a, b = Obj:unpack({ 1, 2 }) @@ -75,7 +75,7 @@ mod test { local Obj = {} ---@generic T - ---@param t T + ---@param t std.ConstTpl ---@return std.Unpack function Obj:unpack(t) end From 8b39362528b81c5295659fe9b40029f8616876ca Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Tue, 19 May 2026 01:18:59 +0800 Subject: [PATCH 06/10] refactor(generic): add inference core --- .../analyzer/lua/for_range_stat.rs | 25 +- .../test/for_range_var_infer_test.rs | 32 + .../src/compilation/test/generic_test.rs | 83 +- .../src/semantic/generic/inference.rs | 2223 +++++++++++++++++ .../infer_call_func_generic.rs | 120 +- .../instantiate_type/inference_widening.rs | 610 +++-- .../instantiate_conditional_generic.rs | 37 +- .../semantic/generic/instantiate_type/mod.rs | 3 +- .../src/semantic/generic/mod.rs | 211 +- .../src/semantic/generic/test.rs | 53 + .../src/semantic/generic/tpl_context.rs | 77 - .../tpl_pattern/generic_tpl_pattern.rs | 138 - .../generic/tpl_pattern/lambda_tpl_pattern.rs | 22 - .../src/semantic/generic/tpl_pattern/mod.rs | 1022 -------- .../src/semantic/generic/type_substitutor.rs | 258 +- .../src/semantic/infer/infer_call/mod.rs | 2 +- 16 files changed, 2957 insertions(+), 1959 deletions(-) create mode 100644 crates/emmylua_code_analysis/src/semantic/generic/inference.rs delete mode 100644 crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs delete mode 100644 crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs delete mode 100644 crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs delete mode 100644 crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs index 8a3565802..0d829df86 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs @@ -2,9 +2,8 @@ use emmylua_parser::{LuaAstToken, LuaExpr, LuaForRangeStat}; use crate::{ DbIndex, InferFailReason, LuaDeclId, LuaInferCache, LuaOperatorMetaMethod, LuaType, - LuaTypeCache, TplContext, TypeOps, TypeSubstitutor, VariadicType, - compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_type_generic, - tpl_pattern_match_args, + LuaTypeCache, TypeOps, VariadicType, compilation::analyzer::unresolve::UnResolveIterVar, + infer_expr, instantiate_doc_function_by_arg_types, }; use super::LuaAnalyzer; @@ -144,24 +143,8 @@ pub fn infer_for_range_iter_expr_func( let Some(status_param) = status_param else { return Ok(doc_function.get_variadic_ret()); }; - let mut substitutor = TypeSubstitutor::new(); - let params = doc_function - .get_params() - .iter() - .map(|(_, opt_ty)| opt_ty.clone().unwrap_or(LuaType::Any)) - .collect::>(); - - let mut context = TplContext::new(db, cache, &mut substitutor, None); - tpl_pattern_match_args(&mut context, ¶ms, &[status_param])?; - - let doc_function_ty = LuaType::DocFunction(doc_function.clone()); - let instantiate_func = if let LuaType::DocFunction(f) = - instantiate_type_generic(db, &doc_function_ty, &substitutor) - { - f - } else { - doc_function - }; + let instantiate_func = + instantiate_doc_function_by_arg_types(db, cache, &doc_function, &[status_param])?; Ok(instantiate_func.get_variadic_ret()) } diff --git a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs index 481e06059..87928325c 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs @@ -105,6 +105,38 @@ mod test { assert_eq!(b, LuaType::Integer); } + #[test] + fn test_issue_490_split_iterator_state() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T: table, K, V + ---@param t T + ---@return fun(table: table, index?: K):K, V + ---@return T + local function spairs(t) + return next, t + end + + --- @type table + local t = { a = 1, b = 2, c = 3 } + + local iter, state = spairs(t) + + for name, value in iter, state do + a = name + b = value + end + "#, + ); + + let a = ws.expr_ty("a"); + let b = ws.expr_ty("b"); + assert_eq!(a, LuaType::String); + assert_eq!(b, LuaType::Integer); + } + #[test] fn test_enum_key_pairs() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 3c99430a6..94b316606 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -1443,6 +1443,54 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "\"mode\""); } + #[test] + fn test_const_tpl_candidate_preserves_structural_literals() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias std.ConstTpl unknown + + ---@generic T + ---@param value std.ConstTpl + ---@return T + function keep_const(value) + end + + table_result = keep_const({ kind = "mode", count = 1 }) + table_kind = table_result.kind + table_count = table_result.count + + ---@type { kind: "mode", count: 1 } + local object + + object_result = keep_const(object) + object_kind = object_result.kind + object_count = object_result.count + + ---@type ["mode", 1] + local tuple + + tuple_result = keep_const(tuple) + tuple_first = tuple_result[1] + tuple_second = tuple_result[2] + "#, + ); + + let table_kind = ws.expr_ty("table_kind"); + let table_count = ws.expr_ty("table_count"); + let object_kind = ws.expr_ty("object_kind"); + let object_count = ws.expr_ty("object_count"); + let tuple_first = ws.expr_ty("tuple_first"); + let tuple_second = ws.expr_ty("tuple_second"); + + assert_eq!(ws.humanize_type(table_kind), "\"mode\""); + assert_eq!(ws.humanize_type(table_count), "1"); + assert_eq!(ws.humanize_type(object_kind), "\"mode\""); + assert_eq!(ws.humanize_type(object_count), "1"); + assert_eq!(ws.humanize_type(tuple_first), "\"mode\""); + assert_eq!(ws.humanize_type(tuple_second), "1"); + } + #[test] fn test_plain_tpl_top_level_return_preserves_primitive_literal() { let mut ws = VirtualWorkspace::new(); @@ -1588,41 +1636,6 @@ mod test { assert_eq!(result_ty, ws.ty("{ kind: string, self: table }")); } - #[test] - fn test_contextual_widening_keeps_bare_literal_but_widens_nested_literals() { - use crate::{ - LuaMemberKey, LuaObjectType, WideningContext, WideningGuard, widen_type_with_context, - }; - use smol_str::SmolStr; - - let mut ws = VirtualWorkspace::new(); - let bare = LuaType::StringConst(SmolStr::new("mode").into()); - assert_eq!( - widen_type_with_context( - bare.clone(), - WideningContext::Root, - &mut WideningGuard::default() - ), - bare - ); - - let object = LuaType::Object( - LuaObjectType::new_with_fields( - [( - LuaMemberKey::Name("kind".into()), - LuaType::StringConst(SmolStr::new("mode").into()), - )] - .into_iter() - .collect(), - Vec::new(), - ) - .into(), - ); - let widened = - widen_type_with_context(object, WideningContext::Root, &mut WideningGuard::default()); - assert_eq!(widened, ws.ty("{ kind: string }")); - } - #[test] fn test_extends_true() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference.rs new file mode 100644 index 000000000..896eb00a4 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference.rs @@ -0,0 +1,2223 @@ +use std::{collections::HashMap as StdHashMap, ops::Deref, sync::Arc}; + +use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr}; +use hashbrown::{HashMap, HashSet}; +use itertools::Itertools; +use rowan::NodeOrToken; +use smol_str::SmolStr; + +use crate::{ + DbIndex, GenericTpl, GenericTplId, InferFailReason, InferGuard, InferGuardRef, LuaFunctionType, + LuaGenericType, LuaInferCache, LuaMemberInfo, LuaMemberKey, LuaMemberOwner, LuaSemanticDeclId, + LuaTupleType, LuaType, LuaTypeDeclId, LuaTypeNode, LuaUnionType, SemanticDeclLevel, TypeOps, + TypeSubstitutor, VariadicType, check_type_compact, infer_node_semantic_decl, + instantiate_type_generic, + semantic::{ + generic::{ + is_primitive_or_literal_type, regularize_tpl_candidate_type, widen_tpl_candidate_type, + }, + member::{find_index_operations, get_member_map}, + }, +}; + +use super::type_substitutor::TplBinding; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(in crate::semantic::generic) enum InferenceCandidateView { + FreshExpression, + RegularType, + ConstPreserving, + Ordinary, +} + +#[derive(Debug, Clone, PartialEq)] +pub(in crate::semantic::generic) struct InferenceCandidate { + ty: LuaType, + view: InferenceCandidateView, +} + +impl InferenceCandidate { + pub(in crate::semantic::generic) fn from_expr_arg(expr: Option<&LuaExpr>, ty: LuaType) -> Self { + if is_literal_candidate(&ty) { + if expr.is_some_and(is_fresh_literal_expr) { + return Self::fresh_expression(ty); + } + + return Self::regular_type(ty); + } + + Self::ordinary(ty) + } + + pub(in crate::semantic::generic) fn regular_type(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::RegularType, + } + } + + pub(in crate::semantic::generic) fn const_preserving(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::ConstPreserving, + } + } + + pub(in crate::semantic::generic) fn ordinary(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::Ordinary, + } + } + + fn fresh_expression(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::FreshExpression, + } + } + + fn candidate_type(&self) -> LuaType { + self.ty.clone() + } + + fn is_const_preserving(&self) -> bool { + self.view == InferenceCandidateView::ConstPreserving + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(in crate::semantic::generic) enum InferencePriority { + Normal, + Return, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(in crate::semantic::generic) enum InferenceVariance { + Covariant, + Contravariant, +} + +impl InferenceVariance { + fn flip(self) -> Self { + match self { + InferenceVariance::Covariant => InferenceVariance::Contravariant, + InferenceVariance::Contravariant => InferenceVariance::Covariant, + } + } +} + +#[derive(Debug, Clone)] +pub(in crate::semantic::generic) enum InferenceResult { + Type(LuaType), + MultiTypes(Vec), + VariadicParams(Vec<(String, Option)>), + VariadicBase(LuaType), +} + +#[derive(Debug, Clone)] +struct InferenceRecord { + candidate: InferenceCandidate, + priority: InferencePriority, +} + +#[derive(Debug, Clone, Default)] +pub(in crate::semantic::generic) struct InferenceInfo { + // 协变候选, 例如参数位置或返回值位置的正向推断. + covariant: Vec, + // 逆变候选, 例如函数参数约束的反向传播. + contravariant: Vec, + // 多返回值或变参展开时收集的候选. + multi: Vec, + // 已经固定下来的结果, 例如显式泛型或变参参数表. + fixed: Option, + // 最近一次写入的推断结果, 便于在后续阶段判断是否需要重算. + inferred: Option, + // 当前信息是否始终处于顶层位置. + top_level: bool, + // 该模板收到的最高优先级. + priority: Option, +} + +impl InferenceInfo { + fn new() -> Self { + Self { + top_level: true, + ..Self::default() + } + } + + fn add_candidate( + &mut self, + variance: InferenceVariance, + candidate: InferenceCandidate, + top_level: bool, + priority: InferencePriority, + ) { + self.top_level &= top_level; + self.priority = Some( + self.priority + .map_or(priority, |current| current.max(priority)), + ); + self.inferred = None; + let record = InferenceRecord { + candidate, + priority, + }; + match variance { + InferenceVariance::Covariant => self.covariant.push(record), + InferenceVariance::Contravariant => self.contravariant.push(record), + } + } +} + +#[derive(Debug)] +pub(in crate::semantic::generic) struct InferenceContext<'a> { + pub db: &'a DbIndex, + pub cache: &'a mut LuaInferCache, + pub call_expr: Option, + infos: HashMap, +} + +impl<'a> InferenceContext<'a> { + pub fn new( + db: &'a DbIndex, + cache: &'a mut LuaInferCache, + call_expr: Option, + ) -> Self { + Self { + db, + cache, + call_expr, + infos: HashMap::new(), + } + } + + pub fn prepare_inference_slots(&mut self, tpl_ids: HashSet) { + for tpl_id in tpl_ids { + if tpl_id.is_conditional_infer() { + continue; + } + + self.infos.entry(tpl_id).or_insert_with(InferenceInfo::new); + } + } + + pub fn has_unresolved_inference_slots(&self) -> bool { + self.infos.values().any(|info| { + info.fixed.is_none() + && info.covariant.is_empty() + && info.contravariant.is_empty() + && info.multi.is_empty() + }) + } + + pub fn fix_type(&mut self, tpl_id: GenericTplId, ty: LuaType) { + if tpl_id.is_conditional_infer() { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + info.fixed = Some(InferenceResult::Type(ty)); + info.inferred = info.fixed.clone(); + } + + pub fn add_variadic_params( + &mut self, + tpl_id: GenericTplId, + params: Vec<(String, Option)>, + ) { + if !self.can_bind(tpl_id) { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + if info.fixed.is_some() + || !info.multi.is_empty() + || !info.covariant.is_empty() + || !info.contravariant.is_empty() + { + return; + } + + info.fixed = Some(InferenceResult::VariadicParams(params)); + info.inferred = info.fixed.clone(); + } + + pub fn add_variadic_base(&mut self, tpl_id: GenericTplId, base: LuaType) { + if !self.can_bind(tpl_id) { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + if info.fixed.is_some() + || !info.multi.is_empty() + || !info.covariant.is_empty() + || !info.contravariant.is_empty() + { + return; + } + + info.fixed = Some(InferenceResult::VariadicBase(base)); + info.inferred = info.fixed.clone(); + } + + pub fn insert_type( + &mut self, + tpl_id: GenericTplId, + candidate: InferenceCandidate, + variance: InferenceVariance, + top_level: bool, + priority: InferencePriority, + ) { + if !self.can_bind(tpl_id) { + return; + } + + self.infos + .entry(tpl_id) + .or_default() + .add_candidate(variance, candidate, top_level, priority); + } + + pub fn insert_multi_types( + &mut self, + tpl_id: GenericTplId, + types: Vec, + view: InferenceCandidateView, + top_level: bool, + priority: InferencePriority, + ) { + if !self.can_bind(tpl_id) { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + let fixed = if types.len() == 1 { + Some(InferenceResult::VariadicParams(vec![( + "var0".to_string(), + Some(types[0].clone()), + )])) + } else { + Some(InferenceResult::MultiTypes(types.clone())) + }; + if info.fixed.is_some() + || !info.multi.is_empty() + || !info.covariant.is_empty() + || !info.contravariant.is_empty() + { + if top_level { + info.fixed = fixed; + info.inferred = info.fixed.clone(); + } + return; + } + + info.multi = types + .into_iter() + .map(|ty| InferenceRecord { + candidate: InferenceCandidate { ty, view }, + priority, + }) + .collect(); + info.top_level &= top_level; + info.priority = Some( + info.priority + .map_or(priority, |current| current.max(priority)), + ); + info.inferred = None; + } + + fn inferred_variadic_len(&self, tpl_id: GenericTplId) -> Option { + let info = self.infos.get(&tpl_id)?; + if let Some(fixed) = &info.fixed { + return match fixed { + InferenceResult::Type(_) => Some(1), + InferenceResult::MultiTypes(types) => Some(types.len().max(1)), + InferenceResult::VariadicParams(params) => Some(params.len().max(1)), + InferenceResult::VariadicBase(_) => None, + }; + } + + if !info.multi.is_empty() { + return Some(info.multi.len().max(1)); + } + + if !info.covariant.is_empty() || !info.contravariant.is_empty() { + return Some(1); + } + + None + } + + pub fn bridge_to_substitutor<'b>( + &mut self, + substitutor: &mut TypeSubstitutor, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + ) { + let generic_tpls = generic_tpls.into_iter().collect::>(); + let tpl_ids = generic_tpls.iter().map(|tpl| tpl.get_tpl_id()).collect(); + substitutor.prepare_inference_slots(tpl_ids); + self.bridge_to_substitutor_inner(substitutor, generic_tpls, return_type); + } + + pub fn bridge_resolved_to_substitutor<'b>( + &mut self, + substitutor: &mut TypeSubstitutor, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + ) { + self.bridge_to_substitutor_inner(substitutor, generic_tpls, return_type); + } + + fn bridge_to_substitutor_inner<'b>( + &mut self, + substitutor: &mut TypeSubstitutor, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + ) { + for tpl in generic_tpls { + let tpl_id = tpl.get_tpl_id(); + let return_top_level = is_tpl_at_top_level(self.db, return_type, tpl_id); + let result = self + .infos + .get(&tpl_id) + .and_then(|info| resolve_info(self.db, tpl, info, return_top_level, substitutor)); + if let Some(result) = result { + write_result_to_substitutor(substitutor, tpl_id, result); + } + } + } + + fn can_bind(&self, tpl_id: GenericTplId) -> bool { + !tpl_id.is_conditional_infer() + } +} + +pub(in crate::semantic::generic) fn infer_types( + context: &mut InferenceContext, + source: &LuaType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, +) -> Result<(), InferFailReason> { + infer_types_inner( + context, + source, + target, + original_target, + variance, + priority, + None, + &InferGuard::new(), + ) +} + +pub(in crate::semantic::generic) fn infer_types_from_expr( + context: &mut InferenceContext, + source: &LuaType, + target: &LuaType, + original_target: &LuaType, + arg_expr: &LuaExpr, +) -> Result<(), InferFailReason> { + infer_types_inner( + context, + source, + target, + original_target, + InferenceVariance::Covariant, + InferencePriority::Normal, + Some(arg_expr), + &InferGuard::new(), + ) +} + +fn infer_types_inner( + context: &mut InferenceContext, + source: &LuaType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let target = escape_alias(context.db, target); + if !source.contains_tpl_node() { + return Ok(()); + } + + let top_level = target == *original_target; + match source { + LuaType::TplRef(tpl) => { + if tpl.get_tpl_id().is_func() { + let candidate = match variance { + InferenceVariance::Covariant => { + InferenceCandidate::from_expr_arg(arg_expr, target.clone()) + } + InferenceVariance::Contravariant => InferenceCandidate::ordinary(target), + }; + context.insert_type(tpl.get_tpl_id(), candidate, variance, top_level, priority); + } + } + LuaType::ConstTplRef(tpl) => { + if tpl.get_tpl_id().is_func() { + context.insert_type( + tpl.get_tpl_id(), + InferenceCandidate::const_preserving(target), + variance, + top_level, + priority, + ); + } + } + LuaType::StrTplRef(str_tpl) => { + if let LuaType::StringConst(s) = target { + let type_name = SmolStr::new(format!( + "{}{}{}", + str_tpl.get_prefix(), + s, + str_tpl.get_suffix() + )); + context.insert_type( + str_tpl.get_tpl_id(), + InferenceCandidate::regular_type(get_str_tpl_infer_type(&type_name)), + variance, + top_level, + priority, + ); + } + } + LuaType::Array(array_type) => { + array_infer_types( + context, + array_type.get_base(), + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::TableGeneric(params) => { + table_generic_infer_types( + context, + params, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Generic(generic) => { + generic_infer_types( + context, + generic, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Union(union) => { + let members = union.into_vec(); + let mut error_count = 0; + let mut last_error = InferFailReason::None; + for member in &members { + match infer_types_inner( + context, + member, + &target, + original_target, + variance, + priority, + arg_expr, + &infer_guard.fork(), + ) { + Ok(_) => {} + Err(err) => { + error_count += 1; + last_error = err; + } + } + } + if error_count == members.len() { + return Err(last_error); + } + } + LuaType::DocFunction(func) => { + function_infer_types( + context, + func, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Tuple(tuple) => { + tuple_infer_types( + context, + tuple, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Object(object) => { + object_infer_types( + context, + object, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +fn array_infer_types( + context: &mut InferenceContext, + base: &LuaType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::Array(target_array) => infer_types_inner( + context, + base, + target_array.get_base(), + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?, + LuaType::Tuple(target_tuple) => { + let target_base = target_tuple.cast_down_array_base(context.db); + infer_types_inner( + context, + base, + &target_base, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Object(target_object) => { + let target_base = target_object + .cast_down_array_base(context.db) + .ok_or(InferFailReason::None)?; + infer_types_inner( + context, + base, + &target_base, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +fn table_generic_infer_types( + context: &mut InferenceContext, + table_params: &[LuaType], + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + if table_params.len() != 2 { + return Err(InferFailReason::None); + } + + match target { + LuaType::TableGeneric(target_params) => { + let min_len = table_params.len().min(target_params.len()); + for i in 0..min_len { + infer_types_inner( + context, + &table_params[i], + &target_params[i], + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + LuaType::Array(target_array) => { + infer_types_inner( + context, + &table_params[0], + &LuaType::Integer, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + target_array.get_base(), + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Tuple(target_tuple) => { + let keys = (0..target_tuple.get_types().len()) + .map(|i| LuaType::IntegerConst((i as i64) + 1)) + .collect::>(); + let key_type = LuaType::Union(LuaUnionType::from_vec(keys).into()); + let target_base = target_tuple.cast_down_array_base(context.db); + infer_types_inner( + context, + &table_params[0], + &key_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + &target_base, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::TableConst(inst) => { + table_generic_member_owner_infer_types( + context, + table_params, + LuaMemberOwner::Element(inst.clone()), + &[], + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + table_generic_member_owner_infer_types( + context, + table_params, + LuaMemberOwner::Type(type_id.clone()), + &[], + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Generic(generic) => { + table_generic_member_owner_infer_types( + context, + table_params, + LuaMemberOwner::Type(generic.get_base_type_id()), + generic.get_params(), + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Object(obj) => { + let mut keys = + Vec::with_capacity(obj.get_fields().len() + obj.get_index_access().len()); + let mut values = + Vec::with_capacity(obj.get_fields().len() + obj.get_index_access().len()); + for (key, value) in obj.get_fields() { + match key { + LuaMemberKey::Integer(i) => keys.push(LuaType::IntegerConst(*i)), + LuaMemberKey::Name(name) => { + keys.push(LuaType::StringConst(name.clone().into())) + } + _ => {} + } + values.push(value.clone()); + } + for (key, value) in obj.get_index_access() { + keys.push(key.clone()); + values.push(value.clone()); + } + let key_type = LuaType::Union(LuaUnionType::from_vec(keys).into()); + let value_type = LuaType::Union(LuaUnionType::from_vec(values).into()); + infer_types_inner( + context, + &table_params[0], + &key_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + &value_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Global | LuaType::Any | LuaType::Table | LuaType::Userdata => { + infer_types_inner( + context, + &table_params[0], + &LuaType::Any, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + &LuaType::Any, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn table_generic_member_owner_infer_types( + context: &mut InferenceContext, + table_params: &[LuaType], + owner: LuaMemberOwner, + target_params: &[LuaType], + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + if table_params.len() != 2 { + return Err(InferFailReason::None); + } + + let owner_type = match &owner { + LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()), + LuaMemberOwner::Type(type_id) => match target_params.len() { + 0 => LuaType::Ref(type_id.clone()), + _ => LuaType::Generic(Arc::new(LuaGenericType::new( + type_id.clone(), + target_params.to_vec(), + ))), + }, + _ => return Err(InferFailReason::None), + }; + + let members = get_member_map(context.db, &owner_type).ok_or(InferFailReason::None)?; + if is_pairs_call(context).unwrap_or(false) + && try_handle_pairs_metamethod( + context, + table_params, + &members, + original_target, + variance, + priority, + arg_expr, + infer_guard, + ) + .is_ok() + { + return Ok(()); + } + + let target_key_type = table_params[0].clone(); + let mut keys = Vec::with_capacity(members.len()); + let mut values = Vec::with_capacity(members.len()); + for (key, members) in members { + let key_type = match key { + LuaMemberKey::Integer(i) => LuaType::IntegerConst(i), + LuaMemberKey::Name(name) => LuaType::StringConst(name.clone().into()), + LuaMemberKey::ExprType(ty) => ty, + _ => continue, + }; + + if !target_key_type.is_generic() + && check_type_compact(context.db, &target_key_type, &key_type).is_err() + { + continue; + } + + keys.push(key_type); + values.push(member_infos_type(members)); + } + + if keys.is_empty() { + find_index_operations(context.db, &owner_type) + .ok_or(InferFailReason::None)? + .iter() + .for_each(|member| { + if target_key_type.is_generic() { + return; + } + let LuaMemberKey::ExprType(key_type) = &member.key else { + return; + }; + if check_type_compact(context.db, &target_key_type, key_type).is_ok() { + keys.push(key_type.clone()); + values.push(member.typ.clone()); + } + }); + } + + let key_type = match &keys[..] { + [] => return Err(InferFailReason::None), + [first] => first.clone(), + _ => LuaType::Union(LuaUnionType::from_vec(keys).into()), + }; + let value_type = match &values[..] { + [first] => first.clone(), + _ => LuaType::Union(LuaUnionType::from_vec(values).into()), + }; + + infer_types_inner( + context, + &table_params[0], + &key_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + &value_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + ) +} + +fn generic_infer_types( + context: &mut InferenceContext, + source_generic: &LuaGenericType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::Generic(target_generic) => { + let source_base = source_generic.get_base_type_id_ref(); + let target_base = target_generic.get_base_type_id_ref(); + if source_base == target_base { + for (start, (source_param, target_param)) in source_generic + .get_params() + .iter() + .zip(target_generic.get_params()) + .enumerate() + { + match source_param { + LuaType::Variadic(variadic) => { + variadic_infer_types( + context, + variadic, + &target_generic.get_params()[start..], + original_target, + variance, + priority, + )?; + break; + } + _ => infer_types_inner( + context, + source_param, + target_param, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?, + } + } + return Ok(()); + } + + let target_decl = context + .db + .get_type_index() + .get_type_decl(target_base) + .ok_or(InferFailReason::None)?; + if target_decl.is_alias() { + let substitutor = TypeSubstitutor::from_alias( + context.db, + target_generic.get_params().clone(), + target_base.clone(), + ); + if let Some(origin_type) = + target_decl.get_alias_origin(context.db, Some(&substitutor)) + { + return generic_infer_types( + context, + source_generic, + &origin_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + ); + } + } else if let Some(super_types) = + context.db.get_type_index().get_super_types(target_base) + { + for mut super_type in super_types { + if super_type.contains_tpl_node() { + let substitutor = + TypeSubstitutor::from_type_array(target_generic.get_params().clone()); + super_type = + instantiate_type_generic(context.db, &super_type, &substitutor); + } + generic_infer_types( + context, + source_generic, + &super_type, + original_target, + variance, + priority, + arg_expr, + &infer_guard.fork(), + )?; + } + } + } + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + infer_guard.check(type_id)?; + let type_decl = context + .db + .get_type_index() + .get_type_decl(type_id) + .ok_or(InferFailReason::None)?; + if let Some(origin_type) = type_decl.get_alias_origin(context.db, None) { + return generic_infer_types( + context, + source_generic, + &origin_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + ); + } + + for super_type in context + .db + .get_type_index() + .get_super_types(type_id) + .unwrap_or_default() + { + generic_infer_types( + context, + source_generic, + &super_type, + original_target, + variance, + priority, + arg_expr, + &infer_guard.fork(), + )?; + } + } + LuaType::Union(union) => { + for member in union.into_vec() { + generic_infer_types( + context, + source_generic, + &member, + original_target, + variance, + priority, + arg_expr, + &infer_guard.fork(), + )?; + } + } + _ => { + let substitutor = TypeSubstitutor::new(); + let generic_ty = LuaType::Generic(source_generic.clone().into()); + let ty = instantiate_type_generic(context.db, &generic_ty, &substitutor); + if LuaType::from(source_generic.clone()) != ty { + infer_types_inner( + context, + &ty, + target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + } + + Ok(()) +} + +fn function_infer_types( + context: &mut InferenceContext, + source_func: &LuaFunctionType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::DocFunction(target_func) => { + function_doc_infer_types( + context, + source_func, + target_func, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Signature(signature_id) => { + let signature = context + .db + .get_signature_index() + .get(signature_id) + .ok_or(InferFailReason::None)?; + if !signature.is_resolve_return() { + return check_lambda_inference(context, *signature_id); + } + + let fake_func = signature.to_doc_func_type(); + function_doc_infer_types( + context, + source_func, + &fake_func, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +fn function_doc_infer_types( + context: &mut InferenceContext, + source_func: &LuaFunctionType, + target_func: &LuaFunctionType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let mut source_params = source_func.get_params().to_vec(); + if source_func.is_colon_define() { + source_params.insert(0, ("self".to_string(), Some(LuaType::Any))); + } + + let mut target_params = target_func.get_params().to_vec(); + if target_func.is_colon_define() { + target_params.insert(0, ("self".to_string(), Some(LuaType::Any))); + } + + param_list_infer_types( + context, + &source_params, + &target_params, + original_target, + variance.flip(), + priority, + arg_expr, + infer_guard, + )?; + return_type_infer_types( + context, + source_func.get_ret(), + target_func.get_ret(), + original_target, + variance, + priority, + arg_expr, + infer_guard, + ) +} + +#[allow(clippy::too_many_arguments)] +fn param_list_infer_types( + context: &mut InferenceContext, + sources: &[(String, Option)], + targets: &[(String, Option)], + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let mut target_offset = 0; + for i in 0..sources.len() { + let source = match sources.get(i) { + Some((_, ty)) => ty.clone().unwrap_or(LuaType::Any), + None => break, + }; + + match &source { + LuaType::Variadic(inner) => { + let i = i + target_offset; + if i >= targets.len() { + if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() { + context.insert_type( + tpl_ref.get_tpl_id(), + InferenceCandidate::ordinary(LuaType::Nil), + variance, + false, + priority, + ); + } + break; + } + + if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() + && let Some(len) = context.inferred_variadic_len(tpl_ref.get_tpl_id()) + { + target_offset += len - 1; + continue; + } + + let mut target_rest_params = &targets[i..]; + if i + 1 < sources.len() { + let source_rest_len = sources.len() - i - 1; + if source_rest_len >= target_rest_params.len() { + continue; + } + let target_rest_len = target_rest_params.len() - source_rest_len; + target_rest_params = &target_rest_params[..target_rest_len]; + if target_rest_len > 1 { + target_offset += target_rest_len - 1; + } + } + + function_varargs_infer_types(context, inner, target_rest_params)?; + } + _ => { + let target = match targets.get(i + target_offset) { + Some((_, ty)) => ty.clone().unwrap_or(LuaType::Any), + None => break, + }; + infer_types_inner( + context, + &source, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub(in crate::semantic::generic) fn return_type_infer_types( + context: &mut InferenceContext, + source: &LuaType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match (source, target) { + (LuaType::Variadic(source_variadic), LuaType::Variadic(target_variadic)) => { + match target_variadic.deref() { + VariadicType::Base(target_base) => match source_variadic.deref() { + VariadicType::Base(source_base) => { + if let LuaType::TplRef(tpl_ref) = source_base { + context.insert_type( + tpl_ref.get_tpl_id(), + InferenceCandidate::ordinary(target_base.clone()), + variance, + false, + priority, + ); + } + } + VariadicType::Multi(source_multi) => { + for ret_type in source_multi { + match ret_type { + LuaType::Variadic(inner) => { + if let VariadicType::Base(base) = inner.deref() + && let LuaType::TplRef(tpl_ref) = base + { + context.insert_type( + tpl_ref.get_tpl_id(), + InferenceCandidate::ordinary(target_base.clone()), + variance, + false, + priority, + ); + } + break; + } + LuaType::TplRef(tpl_ref) => { + context.insert_type( + tpl_ref.get_tpl_id(), + InferenceCandidate::ordinary(target_base.clone()), + variance, + false, + priority, + ); + } + _ => {} + } + } + } + }, + VariadicType::Multi(target_types) => { + variadic_infer_types( + context, + source_variadic, + target_types, + original_target, + variance, + priority, + )?; + } + } + } + (LuaType::Variadic(variadic), _) => { + variadic_infer_types( + context, + variadic, + std::slice::from_ref(target), + original_target, + variance, + priority, + )?; + } + (_, LuaType::Variadic(variadic)) => { + multi_param_infer_multi_return( + context, + std::slice::from_ref(source), + variadic, + original_target, + variance, + priority, + )?; + } + _ => infer_types_inner( + context, + source, + target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?, + } + + Ok(()) +} + +fn function_varargs_infer_types( + context: &mut InferenceContext, + variadic: &VariadicType, + target_rest_params: &[(String, Option)], +) -> Result<(), InferFailReason> { + if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = variadic { + context.add_variadic_params( + tpl_ref.get_tpl_id(), + target_rest_params + .iter() + .map(|(name, ty)| (name.clone(), ty.clone())) + .collect(), + ); + } + + Ok(()) +} + +pub(in crate::semantic::generic) fn variadic_infer_types( + context: &mut InferenceContext, + source: &VariadicType, + target_rest_types: &[LuaType], + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, +) -> Result<(), InferFailReason> { + match source { + VariadicType::Base(base) => match base { + LuaType::TplRef(tpl_ref) => { + let tpl_id = tpl_ref.get_tpl_id(); + match target_rest_types.len() { + 0 => context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::Nil), + variance, + false, + priority, + ), + 1 => match &target_rest_types[0] { + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Multi(types) => match types.len() { + 0 => context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::Nil), + variance, + false, + priority, + ), + 1 => context.insert_type( + tpl_id, + InferenceCandidate::ordinary(types[0].clone()), + variance, + false, + priority, + ), + _ => context.insert_multi_types( + tpl_id, + types.to_vec(), + InferenceCandidateView::Ordinary, + false, + priority, + ), + }, + VariadicType::Base(base) => { + context.add_variadic_base(tpl_id, base.clone()); + } + }, + target => context.insert_type( + tpl_id, + InferenceCandidate::ordinary(target.clone()), + variance, + false, + priority, + ), + }, + _ => context.insert_multi_types( + tpl_id, + target_rest_types.to_vec(), + InferenceCandidateView::Ordinary, + false, + priority, + ), + } + } + LuaType::ConstTplRef(tpl_ref) => { + let tpl_id = tpl_ref.get_tpl_id(); + match target_rest_types.len() { + 0 => context.insert_type( + tpl_id, + InferenceCandidate::const_preserving(LuaType::Nil), + variance, + false, + priority, + ), + 1 => context.insert_type( + tpl_id, + InferenceCandidate::const_preserving(target_rest_types[0].clone()), + variance, + false, + priority, + ), + _ => context.insert_multi_types( + tpl_id, + target_rest_types.to_vec(), + InferenceCandidateView::ConstPreserving, + false, + priority, + ), + } + } + _ => {} + }, + VariadicType::Multi(multi) => { + for (i, ret_type) in multi.iter().enumerate() { + match ret_type { + LuaType::Variadic(inner) => { + if i < target_rest_types.len() { + variadic_infer_types( + context, + inner, + &target_rest_types[i..], + original_target, + variance, + priority, + )?; + } + break; + } + LuaType::TplRef(tpl_ref) => { + let Some(target) = target_rest_types.get(i) else { + break; + }; + context.insert_type( + tpl_ref.get_tpl_id(), + InferenceCandidate::ordinary(target.clone()), + variance, + target == original_target, + priority, + ); + } + _ => {} + } + } + } + } + + Ok(()) +} + +pub(in crate::semantic::generic) fn multi_param_infer_multi_return( + context: &mut InferenceContext, + source_params: &[LuaType], + multi_return: &VariadicType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, +) -> Result<(), InferFailReason> { + match multi_return { + VariadicType::Base(base) => { + let mut target_types = Vec::with_capacity(source_params.len()); + for param in source_params { + if param.is_variadic() { + target_types.push(LuaType::Variadic(multi_return.clone().into())); + break; + } else { + target_types.push(base.clone()); + } + } + infer_type_list( + context, + source_params, + &target_types, + original_target, + variance, + priority, + )?; + } + VariadicType::Multi(_) => { + let mut target_types = Vec::with_capacity(source_params.len()); + for (i, param) in source_params.iter().enumerate() { + let Some(return_type) = multi_return.get_type(i) else { + break; + }; + if param.is_variadic() { + target_types.push(LuaType::Variadic( + multi_return.get_new_variadic_from(i).into(), + )); + break; + } else { + target_types.push(return_type.clone()); + } + } + infer_type_list( + context, + source_params, + &target_types, + original_target, + variance, + priority, + )?; + } + } + + Ok(()) +} + +pub(in crate::semantic::generic) fn infer_type_list( + context: &mut InferenceContext, + source_types: &[LuaType], + target_types: &[LuaType], + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, +) -> Result<(), InferFailReason> { + for (start, (source, target)) in source_types.iter().zip(target_types).enumerate() { + match (source, target) { + (LuaType::Variadic(variadic), _) => { + variadic_infer_types( + context, + variadic, + &target_types[start..], + original_target, + variance, + priority, + )?; + break; + } + (_, LuaType::Variadic(variadic)) => { + multi_param_infer_multi_return( + context, + &source_types[start..], + variadic, + original_target, + variance, + priority, + )?; + break; + } + _ => infer_types(context, source, target, original_target, variance, priority)?, + } + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn tuple_infer_types( + context: &mut InferenceContext, + source_tuple: &LuaTupleType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::Tuple(target_tuple) => { + for (i, source_type) in source_tuple.get_types().iter().enumerate() { + if let LuaType::Variadic(inner) = source_type { + variadic_infer_types( + context, + inner, + &target_tuple.get_types()[i..], + original_target, + variance, + priority, + )?; + break; + } + let Some(target_type) = target_tuple.get_types().get(i) else { + break; + }; + infer_types_inner( + context, + source_type, + target_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + LuaType::Array(target_array) => { + let Some(last_type) = source_tuple.get_types().last() else { + return Err(InferFailReason::None); + }; + if let LuaType::Variadic(inner) = last_type + && let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() + { + context.add_variadic_base(tpl_ref.get_tpl_id(), target_array.get_base().clone()); + } + } + _ => {} + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn object_infer_types( + context: &mut InferenceContext, + source_obj: &crate::LuaObjectType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::Object(target_obj) => { + for (key, value) in source_obj + .get_fields() + .iter() + .sorted_by_key(|(key, _)| *key) + { + if let Some(target_value) = target_obj.get_fields().get(key) { + infer_types_inner( + context, + value, + target_value, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + for (source_key, value) in source_obj.get_index_access() { + let target_access = target_obj + .get_index_access() + .iter() + .find(|(target_key, _)| { + check_type_compact(context.db, source_key, target_key).is_ok() + }); + if let Some((target_key, target_value)) = target_access { + infer_types_inner( + context, + source_key, + target_key, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + value, + target_value, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + } + LuaType::TableConst(inst) => { + object_member_owner_infer_types( + context, + source_obj, + LuaMemberOwner::Element(inst.clone()), + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn object_member_owner_infer_types( + context: &mut InferenceContext, + source_obj: &crate::LuaObjectType, + owner: LuaMemberOwner, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let owner_type = match &owner { + LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()), + LuaMemberOwner::Type(type_id) => LuaType::Ref(type_id.clone()), + _ => return Err(InferFailReason::None), + }; + + let members = get_member_map(context.db, &owner_type).ok_or(InferFailReason::None)?; + for (key, members) in members { + let resolve_type = member_infos_type(members); + if let Some(field_value) = source_obj.get_field(&key) { + infer_types_inner( + context, + field_value, + &resolve_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + + Ok(()) +} + +fn member_infos_type(members: Vec) -> LuaType { + match members.len() { + 0 => LuaType::Any, + 1 => members[0].typ.clone(), + _ => LuaType::from_vec(members.into_iter().map(|member| member.typ).collect()), + } +} + +fn is_pairs_call(context: &mut InferenceContext) -> Option { + let call_expr = context.call_expr.as_ref()?; + let prefix_expr = call_expr.get_prefix_expr()?; + let semantic_decl = match prefix_expr.syntax().clone().into() { + NodeOrToken::Node(node) => infer_node_semantic_decl( + context.db, + context.cache, + node, + SemanticDeclLevel::default(), + ), + _ => None, + }?; + + let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl else { + return None; + }; + let decl = context.db.get_decl_index().get_decl(&decl_id)?; + if !context.db.get_module_index().is_std(&decl.get_file_id()) { + return None; + } + if decl.get_name() != "pairs" { + return None; + } + + Some(true) +} + +#[allow(clippy::too_many_arguments)] +fn try_handle_pairs_metamethod( + context: &mut InferenceContext, + table_params: &[LuaType], + members: &StdHashMap>, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let pairs_member = members + .get(&LuaMemberKey::Name("__pairs".into())) + .ok_or(InferFailReason::None)? + .first() + .ok_or(InferFailReason::None)?; + + let meta_return = match &pairs_member.typ { + LuaType::Signature(signature_id) => context + .db + .get_signature_index() + .get(signature_id) + .map(|signature| signature.get_return_type()), + LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), + _ => None, + } + .ok_or(InferFailReason::None)?; + + let final_return_type = match meta_return { + LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), + LuaType::Signature(signature_id) => context + .db + .get_signature_index() + .get(&signature_id) + .map(|signature| signature.get_return_type()), + _ => None, + }; + + if let Some(LuaType::Variadic(variadic)) = &final_return_type { + let key_type = variadic.get_type(0).ok_or(InferFailReason::None)?; + let value_type = variadic.get_type(1).ok_or(InferFailReason::None)?; + infer_types_inner( + context, + &table_params[0], + key_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + value_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + return Ok(()); + } + + Err(InferFailReason::None) +} + +fn check_lambda_inference( + context: &mut InferenceContext, + signature_id: crate::LuaSignatureId, +) -> Result<(), InferFailReason> { + let call_expr = context.call_expr.as_ref().ok_or(InferFailReason::None)?; + let call_arg_list = call_expr.get_args_list().ok_or(InferFailReason::None)?; + for arg in call_arg_list.get_args() { + if let Ok(LuaType::Signature(arg_signature_id)) = + crate::semantic::infer_expr(context.db, context.cache, arg.clone()) + && arg_signature_id == signature_id + { + return Ok(()); + } + } + + Err(InferFailReason::UnResolveSignatureReturn(signature_id)) +} + +fn resolve_info( + db: &DbIndex, + tpl: &GenericTpl, + info: &InferenceInfo, + return_top_level: bool, + substitutor: &TypeSubstitutor, +) -> Option { + if let Some(fixed) = &info.fixed { + return Some(fixed.clone()); + } + + if !info.multi.is_empty() { + let primitive_constraint = tpl.get_constraint().is_some_and(|constraint| { + let constraint = instantiate_type_generic(db, constraint, substitutor); + is_primitive_or_literal_type(&constraint) + }); + let const_preserving = info + .multi + .iter() + .any(|record| record.candidate.is_const_preserving()); + let preserve_root_literal_form = + primitive_constraint || const_preserving || return_top_level; + return Some(InferenceResult::MultiTypes( + info.multi + .iter() + .map(|record| { + resolve_candidate_type(db, &record.candidate, preserve_root_literal_form) + }) + .collect(), + )); + } + + if !info.covariant.is_empty() { + return Some(InferenceResult::Type(resolve_covariant_candidates( + db, + tpl, + info, + return_top_level, + substitutor, + ))); + } + + if !info.contravariant.is_empty() { + return Some(InferenceResult::Type(resolve_contravariant_candidates( + db, info, + ))); + } + + None +} + +fn write_result_to_substitutor( + substitutor: &mut TypeSubstitutor, + tpl_id: GenericTplId, + result: InferenceResult, +) { + match result { + InferenceResult::Type(ty) => substitutor.bind_type(tpl_id, ty), + InferenceResult::MultiTypes(types) => { + substitutor.bind(tpl_id, TplBinding::InferredMultiTypes(types)); + } + InferenceResult::VariadicParams(params) => { + substitutor.bind(tpl_id, TplBinding::VariadicParams(params)); + } + InferenceResult::VariadicBase(base) => { + substitutor.bind(tpl_id, TplBinding::VariadicBase(base)); + } + } +} + +fn resolve_covariant_candidates( + db: &DbIndex, + tpl: &GenericTpl, + info: &InferenceInfo, + return_top_level: bool, + substitutor: &TypeSubstitutor, +) -> LuaType { + let primitive_constraint = tpl + .get_constraint() + .map(|constraint| { + let constraint = instantiate_type_generic(db, constraint, substitutor); + is_primitive_or_literal_type(&constraint) + }) + .unwrap_or(false); + let const_preserving = info + .covariant + .iter() + .any(|record| record.candidate.is_const_preserving()); + let preserve_root_literal_form = + primitive_constraint || const_preserving || !info.top_level || return_top_level; + + combine_records( + db, + &info.covariant, + info.priority.unwrap_or(InferencePriority::Normal), + preserve_root_literal_form, + TypeOps::Union, + ) +} + +fn resolve_contravariant_candidates(db: &DbIndex, info: &InferenceInfo) -> LuaType { + combine_records( + db, + &info.contravariant, + info.priority.unwrap_or(InferencePriority::Normal), + true, + TypeOps::Intersect, + ) +} + +fn combine_records( + db: &DbIndex, + records: &[InferenceRecord], + max_priority: InferencePriority, + preserve_root_literal_form: bool, + op: TypeOps, +) -> LuaType { + let mut selected = records + .iter() + .filter(|record| record.priority == max_priority) + .map(|record| resolve_candidate_type(db, &record.candidate, preserve_root_literal_form)); + + let Some(first) = selected.next() else { + return LuaType::Unknown; + }; + + selected.fold(first, |acc, ty| op.apply(db, &acc, &ty)) +} + +fn resolve_candidate_type( + db: &DbIndex, + candidate: &InferenceCandidate, + preserve_root_literal_form: bool, +) -> LuaType { + if candidate.is_const_preserving() { + // std.ConstTpl 需要保留结构字面量, 例如 tuple/object/table const. + return candidate.candidate_type(); + } + + if preserve_root_literal_form { + return regularize_tpl_candidate_type(db, candidate.candidate_type()); + } + + match candidate.view { + InferenceCandidateView::FreshExpression | InferenceCandidateView::Ordinary => { + widen_tpl_candidate_type(db, candidate.candidate_type()) + } + _ => regularize_tpl_candidate_type(db, candidate.candidate_type()), + } +} + +pub(in crate::semantic::generic) fn is_literal_candidate(ty: &LuaType) -> bool { + match ty { + LuaType::StringConst(_) + | LuaType::DocStringConst(_) + | LuaType::IntegerConst(_) + | LuaType::DocIntegerConst(_) + | LuaType::FloatConst(_) + | LuaType::BooleanConst(_) + | LuaType::DocBooleanConst(_) + | LuaType::TableConst(_) => true, + LuaType::Union(union) => union.into_vec().iter().any(is_literal_candidate), + LuaType::Tuple(tuple) => tuple.get_types().iter().any(is_literal_candidate), + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => is_literal_candidate(base), + VariadicType::Multi(types) => types.iter().any(is_literal_candidate), + }, + _ => false, + } +} + +fn is_fresh_literal_expr(expr: &LuaExpr) -> bool { + match expr { + LuaExpr::LiteralExpr(_) | LuaExpr::TableExpr(_) => true, + LuaExpr::ParenExpr(paren) => paren + .get_expr() + .is_some_and(|expr| is_fresh_literal_expr(&expr)), + _ => false, + } +} + +fn get_str_tpl_infer_type(name: &str) -> LuaType { + match name { + "unknown" => LuaType::Unknown, + "never" => LuaType::Never, + "nil" | "void" => LuaType::Nil, + "any" => LuaType::Any, + "userdata" => LuaType::Userdata, + "thread" => LuaType::Thread, + "boolean" | "bool" => LuaType::Boolean, + "string" => LuaType::String, + "integer" | "int" => LuaType::Integer, + "number" => LuaType::Number, + "io" => LuaType::Io, + "self" => LuaType::SelfInfer, + "global" => LuaType::Global, + "function" => LuaType::Function, + _ => LuaType::Ref(LuaTypeDeclId::global(name)), + } +} + +fn escape_alias(db: &DbIndex, may_alias: &LuaType) -> LuaType { + if let LuaType::Ref(type_id) = may_alias + && let Some(type_decl) = db.get_type_index().get_type_decl(type_id) + && type_decl.is_alias() + && let Some(origin_type) = type_decl.get_alias_origin(db, None) + { + return origin_type.clone(); + } + + may_alias.clone() +} + +fn is_tpl_at_top_level(db: &DbIndex, ty: &LuaType, tpl_id: GenericTplId) -> bool { + is_tpl_at_top_level_with_guard(db, ty, tpl_id, &mut HashSet::new()) +} + +fn is_tpl_at_top_level_with_guard( + db: &DbIndex, + ty: &LuaType, + tpl_id: GenericTplId, + visited_aliases: &mut HashSet, +) -> bool { + match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + LuaType::Union(union) => union.into_vec().iter().any(|member| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), + LuaType::MultiLineUnion(multi) => multi.get_unions().iter().any(|(member, _)| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), + LuaType::Generic(generic) => { + let type_decl_id = generic.get_base_type_id_ref(); + let Some(alias_param) = + get_transparent_alias_param_index(db, type_decl_id, visited_aliases) + else { + return false; + }; + + generic.get_params().get(alias_param).is_some_and(|param| { + is_tpl_at_top_level_with_guard(db, param, tpl_id, visited_aliases) + }) + } + _ => false, + } +} + +fn get_transparent_alias_param_index( + db: &DbIndex, + type_decl_id: &LuaTypeDeclId, + visited_aliases: &mut HashSet, +) -> Option { + if !visited_aliases.insert(type_decl_id.clone()) { + return None; + } + + let type_decl = db.get_type_index().get_type_decl(type_decl_id)?; + if !type_decl.is_alias() { + return None; + }; + let origin = type_decl.get_alias_ref()?; + + match origin { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + LuaType::Generic(generic) => { + get_transparent_alias_param_index(db, generic.get_base_type_id_ref(), visited_aliases) + .and_then(|alias_param| generic.get_params().get(alias_param)) + .and_then(|param| match param { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + _ => None, + }) + } + _ => None, + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs index b58290125..2ae7fd3d7 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs @@ -12,11 +12,9 @@ use crate::{ semantic::{ LuaInferCache, generic::{ - tpl_context::TplContext, - tpl_pattern::{ - multi_param_tpl_pattern_match_multi_return, return_type_pattern_match_target_type, - tpl_pattern_match, variadic_tpl_pattern_match, - }, + InferenceContext, InferencePriority, InferenceVariance, infer_type_list, + infer_types_from_expr, multi_param_infer_multi_return, return_type_infer_types, + variadic_infer_types, }, infer::InferFailReason, infer_expr, @@ -27,7 +25,7 @@ use crate::{ }; use crate::{ GenericTpl, LuaMemberOwner, LuaSemanticDeclId, LuaTypeOwner, SemanticDeclLevel, TypeVisitTrait, - infer_node_semantic_decl, tpl_pattern_match_args_skip_unknown, + infer_node_semantic_decl, }; use super::{TypeSubstitutor, instantiate_type_generic}; @@ -54,9 +52,9 @@ pub fn infer_call_func_generic( .collect::>(); let mut substitutor = TypeSubstitutor::new(); { - let mut context = TplContext::new(db, cache, &mut substitutor, Some(call_expr.clone())); + let mut context = InferenceContext::new(db, cache, Some(call_expr.clone())); if !generic_tpls.is_empty() { - context.substitutor.prepare_inference_slots(generic_tpls); + context.prepare_inference_slots(generic_tpls); if let Some(type_list) = call_expr.get_call_generic_type_list() { // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 @@ -73,13 +71,14 @@ pub fn infer_call_func_generic( )?; } } + + let func_generic_tpls = func_generic_tpls(func); + context.bridge_to_substitutor(&mut substitutor, func_generic_tpls.iter(), func.get_ret()); } if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { substitutor.add_self_type(self_type); } - substitutor.finalize_inferred_types(db, func_generic_tpls(func).iter(), func.get_ret()); - let func_ty = LuaType::DocFunction(func.clone().into()); if let LuaType::DocFunction(f) = instantiate_type_generic(db, &func_ty, &substitutor) { Ok(f.deref().clone()) @@ -91,15 +90,13 @@ pub fn infer_call_func_generic( fn apply_call_generic_type_list( db: &DbIndex, file_id: FileId, - context: &mut TplContext, + context: &mut InferenceContext, type_list: &LuaDocTypeList, ) { let doc_ctx = DocTypeInferContext::new(db, file_id); for (i, doc_type) in type_list.get_types().enumerate() { let typ = infer_doc_type(doc_ctx, &doc_type); - context - .substitutor - .bind_type(GenericTplId::Func(i as u32), typ); + context.fix_type(GenericTplId::Func(i as u32), typ); } } @@ -120,7 +117,7 @@ fn as_doc_function_type( } fn infer_callable_return_from_arg_types( - context: &mut TplContext, + context: &mut InferenceContext, callable_type: &LuaType, call_arg_types: &[LuaType], ) -> Result, InferFailReason> { @@ -197,7 +194,7 @@ fn uses_erased_function_param(callable: &LuaFunctionType, call_arg_types: &[LuaT } fn infer_callable_return_from_remaining_args( - context: &mut TplContext, + context: &mut InferenceContext, callable_type: &LuaType, arg_exprs: &[LuaExpr], ) -> Result, InferFailReason> { @@ -227,7 +224,7 @@ fn infer_callable_return_from_remaining_args( } fn instantiate_callable_from_arg_types( - context: &mut TplContext, + context: &mut InferenceContext, callable: &Arc, call_arg_types: &[LuaType], ) -> Option> { @@ -250,31 +247,32 @@ fn instantiate_callable_from_arg_types( .iter() .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) .collect::>(); + let callable_generic_tpls = callable_generic_tpls(callable); let mut callable_substitutor = TypeSubstitutor::new(); callable_substitutor.prepare_inference_slots(callable_tpls.clone()); { - let mut callable_context = TplContext::new( - context.db, - context.cache, - &mut callable_substitutor, - context.call_expr.clone(), - ); - if tpl_pattern_match_args_skip_unknown( + let mut callable_context = + InferenceContext::new(context.db, context.cache, context.call_expr.clone()); + callable_context.prepare_inference_slots(callable_tpls.clone()); + if infer_type_list( &mut callable_context, &callable_param_types, call_arg_types, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, ) .is_err() { return None; } + callable_context.bridge_resolved_to_substitutor( + &mut callable_substitutor, + callable_generic_tpls.iter(), + callable.get_ret(), + ); } - callable_substitutor.finalize_inferred_types( - context.db, - callable_generic_tpls(callable).iter(), - callable.get_ret(), - ); let callable_ty = LuaType::DocFunction(callable.clone()); let instantiated = match instantiate_type_generic(context.db, &callable_ty, &callable_substitutor) { @@ -492,7 +490,7 @@ fn collect_func_tpl_dep_from_fallback_type( fn infer_generic_types_from_call( db: &DbIndex, - context: &mut TplContext, + context: &mut InferenceContext, func: &LuaFunctionType, call_expr: &LuaCallExpr, func_params: &mut Vec<(String, LuaType)>, @@ -518,7 +516,7 @@ fn infer_generic_types_from_call( break; } - if !context.substitutor.has_unresolved_inference_slots() { + if !context.has_unresolved_inference_slots() { break; } @@ -547,13 +545,27 @@ fn infer_generic_types_from_call( if let Some(inferred_return_type) = infer_callable_return_from_remaining_args(context, &arg_type, &arg_exprs[i + 1..])? { - return_type_pattern_match_target_type( + return_type_infer_types( context, &return_pattern, &inferred_return_type, + &inferred_return_type, + InferenceVariance::Covariant, + InferencePriority::Return, + None, + &crate::InferGuard::new(), )?; } else if arg_type.is_any() || arg_type.is_unknown() { - return_type_pattern_match_target_type(context, &return_pattern, &LuaType::Unknown)?; + return_type_infer_types( + context, + &return_pattern, + &LuaType::Unknown, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + None, + &crate::InferGuard::new(), + )?; } } @@ -564,7 +576,14 @@ fn infer_generic_types_from_call( let arg_type = infer_expr(db, context.cache, arg_expr.clone())?; arg_types.push(arg_type); } - variadic_tpl_pattern_match(context, variadic, &arg_types)?; + variadic_infer_types( + context, + variadic, + &arg_types, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + )?; break; } (_, LuaType::Variadic(variadic)) => { @@ -573,20 +592,39 @@ fn infer_generic_types_from_call( .map(|(_, t)| t) .cloned() .collect::>(); - multi_param_tpl_pattern_match_multi_return(context, &func_param_types, variadic)?; + multi_param_infer_multi_return( + context, + &func_param_types, + variadic, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + )?; break; } _ => { - tpl_pattern_match(context, func_param_type, &arg_type)?; + infer_types_from_expr( + context, + func_param_type, + &arg_type, + &arg_type, + call_arg_expr, + )?; } } } - if context.substitutor.has_unresolved_inference_slots() { + if context.has_unresolved_inference_slots() { for (func_param_type, call_arg_expr) in unresolve_tpls { - let closure_type = infer_expr(db, context.cache, call_arg_expr)?; - - tpl_pattern_match(context, &func_param_type, &closure_type)?; + let closure_type = infer_expr(db, context.cache, call_arg_expr.clone())?; + + infer_types_from_expr( + context, + &func_param_type, + &closure_type, + &closure_type, + &call_arg_expr, + )?; } } @@ -680,7 +718,7 @@ pub fn infer_self_type( } fn check_expr_can_later_infer( - context: &mut TplContext, + context: &mut InferenceContext, func_param_type: &LuaType, call_arg_expr: &LuaExpr, ) -> Result { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs index 0764615b9..4ec342ace 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs @@ -4,52 +4,12 @@ use hashbrown::{HashMap, HashSet}; use rowan::TextRange; use crate::{ - DbIndex, GenericParam, GenericTpl, InFiled, LuaArrayType, LuaConditionalType, LuaFunctionType, + DbIndex, GenericParam, InFiled, LuaArrayType, LuaConditionalType, LuaFunctionType, LuaGenericType, LuaMappedType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaTupleType, - LuaType, LuaUnionType, TypeOps, TypeSubstitutor, VariadicType, instantiate_type_generic, + LuaType, LuaUnionType, TypeOps, VariadicType, }; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(in crate::semantic::generic) enum TplCandidateSource { - Plain, - ConstPreserving, - Finalized, -} - -pub(in crate::semantic::generic) fn finalize_inferred_tpl_candidate( - db: &DbIndex, - tpl: &GenericTpl, - raw_candidate: &LuaType, - candidate_source: TplCandidateSource, - top_level: bool, - return_top_level: bool, - substitutor: &TypeSubstitutor, -) -> LuaType { - if candidate_source == TplCandidateSource::ConstPreserving { - return raw_candidate.clone(); - } - - let primitive_constraint = tpl - .get_constraint() - .map(|constraint| { - let constraint = instantiate_type_generic(db, constraint, substitutor); - is_primitive_or_literal_type(&constraint) - }) - .unwrap_or(false); - let candidate = if primitive_constraint || !top_level || return_top_level { - raw_candidate.clone() - } else { - widen_literal_type(raw_candidate.clone()) - }; - finalize_tpl_candidate_type( - db, - candidate, - WideningContext::Root, - &mut WideningGuard::default(), - ) -} - -fn is_primitive_or_literal_type(ty: &LuaType) -> bool { +pub(in crate::semantic::generic) fn is_primitive_or_literal_type(ty: &LuaType) -> bool { match ty { LuaType::String | LuaType::Number @@ -78,8 +38,9 @@ fn is_primitive_or_literal_type(ty: &LuaType) -> bool { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WideningContext { +enum WideningContext { Root, + RootUnionMember, UnionMember, ObjectProperty, ArrayElement, @@ -90,7 +51,7 @@ pub enum WideningContext { const MAX_WIDENING_DEPTH: u16 = 100; #[derive(Default)] -pub struct WideningGuard { +struct WideningGuard { depth: u16, active_table_ids: HashSet>, } @@ -117,249 +78,225 @@ impl WideningGuard { } } -fn finalize_tpl_candidate_type( - db: &DbIndex, - ty: LuaType, - context: WideningContext, - guard: &mut WideningGuard, -) -> LuaType { - if !guard.enter_level() { - return match ty { - LuaType::TableConst(_) => LuaType::Table, - ty => widen_literals_with_context(ty, context), - }; - } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum RootPrimitiveBehavior { + PreserveLiteral, + WidenLiteral, +} - let widened = match ty { - LuaType::TableConst(table_id) => { - table_const_to_object(db, table_id, guard).unwrap_or(LuaType::Table) - } - LuaType::Object(object) => { - let fields = object - .get_fields() - .iter() - .map(|(key, ty)| { - ( - key.clone(), - finalize_tpl_candidate_type( - db, - ty.clone(), - WideningContext::ObjectProperty, - guard, - ), - ) - }) - .collect(); - let index_access = object - .get_index_access() - .iter() - .map(|(key, value)| { - ( - widen_type_with_context( - key.clone(), - WideningContext::ObjectProperty, - guard, - ), - finalize_tpl_candidate_type( - db, - value.clone(), - WideningContext::ObjectProperty, - guard, - ), - ) - }) - .collect(); - LuaType::Object(LuaObjectType::new_with_fields(fields, index_access).into()) - } - LuaType::Array(array) => { - let element_context = match context { - WideningContext::TupleElement => WideningContext::TupleElement, - _ => WideningContext::ArrayElement, - }; - let base = - finalize_tpl_candidate_type(db, array.get_base().clone(), element_context, guard); - LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) +struct WideningTransformer<'db> { + db: Option<&'db DbIndex>, + root_primitive_behavior: RootPrimitiveBehavior, + guard: WideningGuard, +} + +impl<'db> WideningTransformer<'db> { + fn new(db: Option<&'db DbIndex>, root_primitive_behavior: RootPrimitiveBehavior) -> Self { + Self { + db, + root_primitive_behavior, + guard: WideningGuard::default(), } - LuaType::Tuple(tuple) => { - let types = tuple - .get_types() - .iter() - .cloned() - .map(|ty| finalize_tpl_candidate_type(db, ty, WideningContext::TupleElement, guard)) - .collect(); - LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) + } + + fn for_candidate_regularization(db: &'db DbIndex) -> Self { + Self::new(Some(db), RootPrimitiveBehavior::PreserveLiteral) + } + + fn for_candidate_widening(db: &'db DbIndex) -> Self { + Self::new(Some(db), RootPrimitiveBehavior::WidenLiteral) + } + + fn transform(&mut self, ty: LuaType, context: WideningContext) -> LuaType { + if !self.guard.enter_level() { + return self.fallback(ty, context); } - LuaType::Union(union) => { - let member_context = if matches!(context, WideningContext::Root) { - WideningContext::Root - } else { - WideningContext::UnionMember - }; - LuaType::Union( - LuaUnionType::from_vec( - union - .into_vec() - .into_iter() - .map(|ty| finalize_tpl_candidate_type(db, ty, member_context, guard)) - .collect(), - ) - .into(), - ) + + let widened = match ty { + LuaType::TableConst(table_id) => self.transform_table_const(table_id), + LuaType::Array(array) => self.transform_array(array, context), + LuaType::Tuple(tuple) => self.transform_tuple(tuple), + LuaType::Object(object) => self.transform_object(object), + LuaType::Union(union) => self.transform_union(union, context), + LuaType::MultiLineUnion(multi) => self.transform_multi_line_union(multi, context), + LuaType::Intersection(intersection) => { + self.transform_intersection(intersection, context) + } + LuaType::Variadic(variadic) => self.transform_variadic(variadic), + LuaType::Generic(generic) => self.transform_generic(generic), + LuaType::TableGeneric(params) => self.transform_table_generic(params), + LuaType::DocFunction(func) => self.transform_doc_function(func), + LuaType::TypeGuard(type_guard) => self.transform_type_guard(type_guard), + LuaType::Conditional(conditional) => self.transform_conditional(conditional), + LuaType::Mapped(mapped) => self.transform_mapped(mapped), + ty => self.transform_terminal(ty, context), + }; + + self.guard.leave_level(); + widened + } + + fn fallback(&self, ty: LuaType, context: WideningContext) -> LuaType { + match (self.db, ty) { + (Some(_), LuaType::TableConst(_)) => LuaType::Table, + (_, ty) => self.transform_terminal(ty, context), } - ty => widen_type_with_context(ty, context, guard), - }; + } - guard.leave_level(); - widened -} + fn transform_table_const(&mut self, table_id: InFiled) -> LuaType { + let Some(db) = self.db else { + return LuaType::TableConst(table_id); + }; -pub fn widen_type_with_context( - ty: LuaType, - context: WideningContext, - guard: &mut WideningGuard, -) -> LuaType { - if !guard.enter_level() { - return widen_literals_with_context(ty, context); + self.table_const_to_object(db, table_id) + .unwrap_or(LuaType::Table) } - let ty = widen_literals_with_context(ty, context); + fn transform_array(&mut self, array: Arc, context: WideningContext) -> LuaType { + let element_context = match context { + WideningContext::TupleElement => WideningContext::TupleElement, + _ => WideningContext::ArrayElement, + }; + let base = self.transform(array.get_base().clone(), element_context); + LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) + } - let widened = match ty { - LuaType::Array(array) => { - let element_context = match context { - WideningContext::TupleElement => WideningContext::TupleElement, - _ => WideningContext::ArrayElement, - }; - let base = widen_type_with_context(array.get_base().clone(), element_context, guard); - LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) - } - LuaType::Tuple(tuple) => { - let types = tuple - .get_types() - .iter() - .cloned() - .map(|ty| widen_type_with_context(ty, WideningContext::TupleElement, guard)) - .collect(); - LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) - } - LuaType::Object(object) => { - let fields = object - .get_fields() - .iter() - .map(|(key, ty)| { - ( - key.clone(), - widen_type_with_context(ty.clone(), WideningContext::ObjectProperty, guard), - ) - }) - .collect(); - let index_access = object - .get_index_access() - .iter() - .map(|(key, value)| { - ( - widen_type_with_context( - key.clone(), - WideningContext::ObjectProperty, - guard, - ), - widen_type_with_context( - value.clone(), - WideningContext::ObjectProperty, - guard, - ), - ) - }) - .collect(); - LuaType::Object(LuaObjectType::new_with_fields(fields, index_access).into()) - } - LuaType::Union(union) => { - let member_context = if matches!(context, WideningContext::Root) { - WideningContext::Root - } else { - WideningContext::UnionMember - }; - LuaType::Union( - LuaUnionType::from_vec( - union - .into_vec() - .into_iter() - .map(|ty| widen_type_with_context(ty, member_context, guard)) - .collect(), + fn transform_tuple(&mut self, tuple: Arc) -> LuaType { + let types = tuple + .get_types() + .iter() + .cloned() + .map(|ty| self.transform(ty, WideningContext::TupleElement)) + .collect(); + LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) + } + + fn transform_object(&mut self, object: Arc) -> LuaType { + let fields = object + .get_fields() + .iter() + .map(|(key, ty)| { + ( + key.clone(), + self.transform(ty.clone(), WideningContext::ObjectProperty), ) - .into(), + }) + .collect(); + let index_access = object + .get_index_access() + .iter() + .map(|(key, value)| { + ( + self.transform(key.clone(), WideningContext::ObjectProperty), + self.transform(value.clone(), WideningContext::ObjectProperty), + ) + }) + .collect(); + LuaType::Object(LuaObjectType::new_with_fields(fields, index_access).into()) + } + + fn transform_union(&mut self, union: Arc, context: WideningContext) -> LuaType { + let member_context = self.union_member_context(context); + LuaType::Union( + LuaUnionType::from_vec( + union + .into_vec() + .into_iter() + .map(|ty| self.transform(ty, member_context)) + .collect(), ) - } - LuaType::MultiLineUnion(multi) => LuaType::MultiLineUnion( + .into(), + ) + } + + fn transform_multi_line_union( + &mut self, + multi: Arc, + context: WideningContext, + ) -> LuaType { + let member_context = self.union_member_context(context); + LuaType::MultiLineUnion( crate::LuaMultiLineUnion::new( multi .get_unions() .iter() .map(|(ty, description)| { ( - widen_type_with_context( - ty.clone(), - WideningContext::UnionMember, - guard, - ), + self.transform(ty.clone(), member_context), description.clone(), ) }) .collect(), ) .into(), - ), - LuaType::Intersection(intersection) => LuaType::Intersection( + ) + } + + fn transform_intersection( + &mut self, + intersection: Arc, + context: WideningContext, + ) -> LuaType { + let member_context = self.union_member_context(context); + LuaType::Intersection( crate::LuaIntersectionType::new( intersection .get_types() .iter() .cloned() - .map(|ty| widen_type_with_context(ty, WideningContext::UnionMember, guard)) + .map(|ty| self.transform(ty, member_context)) .collect(), ) .into(), - ), - LuaType::Variadic(variadic) => LuaType::Variadic( + ) + } + + fn transform_variadic(&mut self, variadic: Arc) -> LuaType { + LuaType::Variadic( match variadic.deref() { - VariadicType::Base(base) => VariadicType::Base(widen_type_with_context( - base.clone(), - WideningContext::VariadicElement, - guard, - )), + VariadicType::Base(base) => VariadicType::Base( + self.transform(base.clone(), WideningContext::VariadicElement), + ), VariadicType::Multi(types) => VariadicType::Multi( types .iter() .cloned() - .map(|ty| { - widen_type_with_context(ty, WideningContext::VariadicElement, guard) - }) + .map(|ty| self.transform(ty, WideningContext::VariadicElement)) .collect(), ), } .into(), - ), - LuaType::Generic(generic) => LuaType::Generic( + ) + } + + fn transform_generic(&mut self, generic: Arc) -> LuaType { + LuaType::Generic( LuaGenericType::new( generic.get_base_type_id(), generic .get_params() .iter() .cloned() - .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)) + .map(|ty| self.transform(ty, WideningContext::Root)) .collect(), ) .into(), - ), - LuaType::TableGeneric(params) => LuaType::TableGeneric( + ) + } + + fn transform_table_generic(&mut self, params: Arc>) -> LuaType { + LuaType::TableGeneric( params .iter() .cloned() - .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)) + .map(|ty| self.transform(ty, WideningContext::Root)) .collect::>() .into(), - ), - LuaType::DocFunction(func) => LuaType::DocFunction( + ) + } + + fn transform_doc_function(&mut self, func: Arc) -> LuaType { + LuaType::DocFunction( LuaFunctionType::new( func.get_async_state(), func.is_colon_define(), @@ -369,48 +306,46 @@ pub fn widen_type_with_context( .map(|(name, ty)| { ( name.clone(), - ty.clone().map(|ty| { - widen_type_with_context(ty, WideningContext::Root, guard) - }), + ty.clone() + .map(|ty| self.transform(ty, WideningContext::Root)), ) }) .collect(), - widen_type_with_context(func.get_ret().clone(), WideningContext::Root, guard), + self.transform(func.get_ret().clone(), WideningContext::Root), ) .into(), - ), - LuaType::TypeGuard(type_guard) => LuaType::TypeGuard( - widen_type_with_context(type_guard.deref().clone(), WideningContext::Root, guard) + ) + } + + fn transform_type_guard(&mut self, type_guard: Arc) -> LuaType { + LuaType::TypeGuard( + self.transform(type_guard.deref().clone(), WideningContext::Root) .into(), - ), - LuaType::Conditional(conditional) => LuaType::Conditional( + ) + } + + fn transform_conditional(&mut self, conditional: Arc) -> LuaType { + LuaType::Conditional( LuaConditionalType::new( - widen_type_with_context( + self.transform( conditional.get_checked_type().clone(), WideningContext::Root, - guard, ), - widen_type_with_context( + self.transform( conditional.get_extends_type().clone(), WideningContext::Root, - guard, - ), - widen_type_with_context( - conditional.get_true_type().clone(), - WideningContext::Root, - guard, - ), - widen_type_with_context( - conditional.get_false_type().clone(), - WideningContext::Root, - guard, ), + self.transform(conditional.get_true_type().clone(), WideningContext::Root), + self.transform(conditional.get_false_type().clone(), WideningContext::Root), conditional.get_infer_params().to_vec(), conditional.has_new, ) .into(), - ), - LuaType::Mapped(mapped) => LuaType::Mapped(Arc::new(LuaMappedType::new( + ) + } + + fn transform_mapped(&mut self, mapped: Arc) -> LuaType { + LuaType::Mapped(Arc::new(LuaMappedType::new( ( mapped.param.0, GenericParam::new( @@ -420,87 +355,122 @@ pub fn widen_type_with_context( .1 .type_constraint .clone() - .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)), + .map(|ty| self.transform(ty, WideningContext::Root)), mapped .param .1 .default_type .clone() - .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)), + .map(|ty| self.transform(ty, WideningContext::Root)), mapped.param.1.attributes.clone(), ), ), - widen_type_with_context(mapped.value.clone(), WideningContext::Root, guard), + self.transform(mapped.value.clone(), WideningContext::Root), mapped.is_readonly, mapped.is_optional, - ))), - ty => ty, - }; + ))) + } - guard.leave_level(); - widened -} + fn transform_terminal(&self, ty: LuaType, context: WideningContext) -> LuaType { + // Keep a top-level literal union intact. Widening `"a" | "b"` to `string` + // would throw away a deliberate literal candidate during inference. + if matches!(context, WideningContext::RootUnionMember) { + return ty; + } -fn widen_literals_with_context(ty: LuaType, context: WideningContext) -> LuaType { - match context { - WideningContext::Root => ty, - _ => widen_literal_type(ty), - } -} + if matches!(context, WideningContext::Root) + && matches!( + self.root_primitive_behavior, + RootPrimitiveBehavior::PreserveLiteral + ) + { + return ty; + } -fn widen_literal_type(ty: LuaType) -> LuaType { - match ty { - LuaType::FloatConst(_) => LuaType::Number, - LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, - LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, - LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, - ty => ty, + widen_primitive_literal(ty) } -} -fn table_const_to_object( - db: &DbIndex, - table_id: InFiled, - guard: &mut WideningGuard, -) -> Option { - let owner = LuaMemberOwner::Element(table_id.clone()); - let members = db.get_member_index().get_members(&owner)?; - if !guard.enter_table(&table_id) { - return Some(LuaType::Table); + fn union_member_context(&self, context: WideningContext) -> WideningContext { + if matches!(context, WideningContext::Root) { + WideningContext::RootUnionMember + } else { + WideningContext::UnionMember + } } - let mut fields = HashMap::new(); - let mut index_access = Vec::new(); - - for member in members { - let value = db - .get_type_index() - .get_type_cache(&member.get_id().into()) - .map(|cache| cache.as_type().clone()) - .unwrap_or(LuaType::Unknown); - let value = finalize_tpl_candidate_type(db, value, WideningContext::ObjectProperty, guard); - - match member.get_key() { - LuaMemberKey::Name(_) | LuaMemberKey::Integer(_) => { - fields - .entry(member.get_key().clone()) - .and_modify(|prev| { - *prev = TypeOps::Union.apply(db, prev, &value); - }) - .or_insert(value); + + fn table_const_to_object( + &mut self, + db: &DbIndex, + table_id: InFiled, + ) -> Option { + if !self.guard.enter_table(&table_id) { + return Some(LuaType::Table); + } + + let owner = LuaMemberOwner::Element(table_id.clone()); + let members = match db.get_member_index().get_members(&owner) { + Some(members) => members, + None => { + self.guard.leave_table(&table_id); + return None; } - LuaMemberKey::ExprType(key) => { - index_access.push(( - widen_type_with_context(key.clone(), WideningContext::ObjectProperty, guard), - value, - )); + }; + let mut fields = HashMap::with_capacity(members.len()); + let mut index_access = Vec::with_capacity(members.len()); + + for member in members { + let value = db + .get_type_index() + .get_type_cache(&member.get_id().into()) + .map(|cache| cache.as_type().clone()) + .unwrap_or(LuaType::Unknown); + let value = self.transform(value, WideningContext::ObjectProperty); + + match member.get_key() { + LuaMemberKey::Name(_) | LuaMemberKey::Integer(_) => { + let member_key = member.get_key().clone(); + fields + .entry(member_key) + .and_modify(|prev| { + *prev = TypeOps::Union.apply(db, prev, &value); + }) + .or_insert(value); + } + LuaMemberKey::ExprType(key) => { + index_access.push(( + self.transform(key.clone(), WideningContext::ObjectProperty), + value, + )); + } + LuaMemberKey::None => {} } - LuaMemberKey::None => {} } + + self.guard.leave_table(&table_id); + + Some(LuaType::Object( + LuaObjectType::new_with_fields(fields, index_access).into(), + )) } +} + +pub(in crate::semantic::generic) fn regularize_tpl_candidate_type( + db: &DbIndex, + ty: LuaType, +) -> LuaType { + WideningTransformer::for_candidate_regularization(db).transform(ty, WideningContext::Root) +} - guard.leave_table(&table_id); +pub(in crate::semantic::generic) fn widen_tpl_candidate_type(db: &DbIndex, ty: LuaType) -> LuaType { + WideningTransformer::for_candidate_widening(db).transform(ty, WideningContext::Root) +} - Some(LuaType::Object( - LuaObjectType::new_with_fields(fields, index_access).into(), - )) +fn widen_primitive_literal(ty: LuaType) -> LuaType { + match ty { + LuaType::FloatConst(_) => LuaType::Number, + LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, + LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, + LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, + ty => ty, + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index bfb24ad50..32a0d1f7c 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -1,17 +1,13 @@ use hashbrown::{HashMap, HashSet}; -use internment::ArcIntern; use crate::{ - DbIndex, GenericTpl, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, + DbIndex, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, check_type_compact, db_index::{LuaObjectType, LuaTupleType, LuaType}, semantic::{member::find_members_with_key, type_check::check_type_compact_with_level}, }; -use super::{ - TplCandidateSource, finalize_inferred_tpl_candidate, get_default_constructor, - instantiate_type_generic_inner, -}; +use super::{get_default_constructor, instantiate_type_generic_inner}; use crate::semantic::generic::type_substitutor::{ GenericInstantiateContext, GenericInstantiateFrame, TplBinding, }; @@ -89,7 +85,7 @@ fn instantiate_conditional_once( context, frame, conditional, - finalize_infer_assignments(context, conditional, infer_assignments), + finalize_infer_assignments(infer_assignments), ) } else if is_deferred_conditional_operand(&left_type) || right_type.any_type(|inner| match inner { @@ -733,8 +729,6 @@ fn insert_infer_assignment( } fn finalize_infer_assignments( - context: &GenericInstantiateContext, - conditional: &LuaConditionalType, assignments: HashMap, ) -> HashMap { assignments @@ -743,30 +737,7 @@ fn finalize_infer_assignments( candidates .covariant .or(candidates.contravariant) - .map(|raw_candidate| { - let Some(param) = conditional.get_infer_params().get(tpl_id.get_idx()) else { - return (tpl_id, raw_candidate); - }; - - let tpl = GenericTpl::new( - tpl_id, - ArcIntern::new(param.name.clone()), - param.type_constraint.clone(), - param.default_type.clone(), - ); - ( - tpl_id, - finalize_inferred_tpl_candidate( - context.db, - &tpl, - &raw_candidate, - TplCandidateSource::ConstPreserving, - true, - true, - context.substitutor, - ), - ) - }) + .map(|raw_candidate| (tpl_id, raw_candidate)) }) .collect() } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index c15dd4c44..aef69e547 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -26,9 +26,8 @@ pub use complete_generic_args::{ }; pub use infer_call_func_generic::{build_self_type, infer_call_func_generic, infer_self_type}; pub(in crate::semantic::generic) use inference_widening::{ - TplCandidateSource, finalize_inferred_tpl_candidate, + is_primitive_or_literal_type, regularize_tpl_candidate_type, widen_tpl_candidate_type, }; -pub use inference_widening::{WideningContext, WideningGuard, widen_type_with_context}; use instantiate_mapped_type::instantiate_mapped_type as instantiate_mapped_type_inner; pub use instantiate_special_generic::get_keyof_members; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index 582e1e9ee..5cd76d080 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -1,36 +1,237 @@ mod call_constraint; +mod inference; mod instantiate_type; mod test; -mod tpl_context; -mod tpl_pattern; mod type_substitutor; +use std::sync::Arc; + pub use call_constraint::{ CallConstraintArg, CallConstraintContext, build_call_constraint_context, normalize_constraint_type, }; use emmylua_parser::LuaAstNode; use emmylua_parser::LuaExpr; +use hashbrown::HashSet; +pub(in crate::semantic::generic) use inference::{ + InferenceContext, InferencePriority, InferenceVariance, infer_type_list, infer_types_from_expr, + multi_param_infer_multi_return, return_type_infer_types, variadic_infer_types, +}; pub use instantiate_type::*; use rowan::NodeOrToken; -pub use tpl_context::TplContext; -pub use tpl_pattern::tpl_pattern_match_args; -pub use tpl_pattern::tpl_pattern_match_args_skip_unknown; pub use type_substitutor::TypeSubstitutor; use crate::DbIndex; +use crate::GenericTpl; use crate::GenericTplId; +use crate::InferFailReason; use crate::LuaDeclExtra; +use crate::LuaFunctionType; use crate::LuaInferCache; use crate::LuaMemberOwner; use crate::LuaSemanticDeclId; use crate::LuaType; +use crate::LuaTypeNode; use crate::SemanticDeclLevel; use crate::TypeOps; use crate::infer_node_semantic_decl; use crate::semantic::semantic_info::infer_token_semantic_decl; pub use instantiate_type::get_keyof_members; +pub fn instantiate_doc_function_by_arg_types( + db: &DbIndex, + cache: &mut LuaInferCache, + doc_function: &Arc, + call_arg_types: &[LuaType], +) -> Result, InferFailReason> { + let generic_tpl_ids = collect_doc_function_tpl_ids(doc_function); + if generic_tpl_ids.is_empty() { + return Ok(doc_function.clone()); + } + + let param_types = doc_function + .get_params() + .iter() + .map(|(_, typ)| typ.clone().unwrap_or(LuaType::Unknown)) + .collect::>(); + let mut context = InferenceContext::new(db, cache, None); + context.prepare_inference_slots(generic_tpl_ids); + infer_type_list( + &mut context, + ¶m_types, + call_arg_types, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + )?; + + let mut substitutor = TypeSubstitutor::new(); + let generic_tpls = collect_doc_function_generic_tpls(doc_function); + context.bridge_to_substitutor( + &mut substitutor, + generic_tpls.iter(), + doc_function.get_ret(), + ); + + let doc_function_ty = LuaType::DocFunction(doc_function.clone()); + Ok( + match instantiate_type_generic(db, &doc_function_ty, &substitutor) { + LuaType::DocFunction(func) => func, + _ => doc_function.clone(), + }, + ) +} + +fn collect_doc_function_tpl_ids(doc_function: &LuaFunctionType) -> HashSet { + let mut generic_tpl_ids = HashSet::new(); + doc_function.visit_nested_types(&mut |ty| match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + collect_function_tpl_with_fallback_deps(&generic_tpl, &mut generic_tpl_ids); + } + LuaType::StrTplRef(str_tpl) => { + let tpl_id = str_tpl.get_tpl_id(); + if !tpl_id.is_func() { + return; + } + + generic_tpl_ids.insert(tpl_id); + let Some(constraint) = str_tpl.get_constraint() else { + return; + }; + + let mut constraint_deps = HashSet::new(); + if collect_function_tpl_deps_from_fallback_type( + constraint, + &mut constraint_deps, + &mut HashSet::new(), + ) { + generic_tpl_ids.extend(constraint_deps); + } + } + _ => {} + }); + + generic_tpl_ids +} + +fn collect_doc_function_generic_tpls(doc_function: &LuaFunctionType) -> Vec> { + let mut generic_tpls = Vec::new(); + doc_function.visit_nested_types(&mut |ty| match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + if generic_tpl.get_tpl_id().is_func() + && !generic_tpls.iter().any(|existing: &Arc| { + existing.get_tpl_id() == generic_tpl.get_tpl_id() + }) + { + generic_tpls.push(generic_tpl.clone()); + } + } + _ => {} + }); + + generic_tpls +} + +fn collect_function_tpl_with_fallback_deps( + generic_tpl: &GenericTpl, + generic_tpl_ids: &mut HashSet, +) { + let tpl_id = generic_tpl.get_tpl_id(); + if !tpl_id.is_func() { + return; + } + + generic_tpl_ids.insert(tpl_id); + let Some(fallback_type) = generic_tpl + .get_default_type() + .or(generic_tpl.get_constraint()) + else { + return; + }; + + let mut fallback_deps = HashSet::new(); + let mut visiting_fallbacks = HashSet::new(); + visiting_fallbacks.insert(tpl_id); + if collect_function_tpl_deps_from_fallback_type( + fallback_type, + &mut fallback_deps, + &mut visiting_fallbacks, + ) { + generic_tpl_ids.extend(fallback_deps); + } +} + +fn collect_function_tpl_deps_from_fallback_type( + ty: &LuaType, + generic_tpl_ids: &mut HashSet, + visiting_fallbacks: &mut HashSet, +) -> bool { + let mut no_fallback_cycle = + collect_function_tpl_dep_from_fallback_type(ty, generic_tpl_ids, visiting_fallbacks); + ty.visit_nested_types(&mut |ty| { + no_fallback_cycle &= + collect_function_tpl_dep_from_fallback_type(ty, generic_tpl_ids, visiting_fallbacks); + }); + no_fallback_cycle +} + +fn collect_function_tpl_dep_from_fallback_type( + ty: &LuaType, + generic_tpl_ids: &mut HashSet, + visiting_fallbacks: &mut HashSet, +) -> bool { + match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + let tpl_id = generic_tpl.get_tpl_id(); + if !tpl_id.is_func() { + return true; + } + + if !visiting_fallbacks.insert(tpl_id) { + return false; + } + + generic_tpl_ids.insert(tpl_id); + let no_fallback_cycle = match generic_tpl + .get_default_type() + .or(generic_tpl.get_constraint()) + { + Some(fallback_type) => collect_function_tpl_deps_from_fallback_type( + fallback_type, + generic_tpl_ids, + visiting_fallbacks, + ), + None => true, + }; + visiting_fallbacks.remove(&tpl_id); + no_fallback_cycle + } + LuaType::StrTplRef(str_tpl) => { + let tpl_id = str_tpl.get_tpl_id(); + if !tpl_id.is_func() { + return true; + } + + if !visiting_fallbacks.insert(tpl_id) { + return false; + } + + generic_tpl_ids.insert(tpl_id); + let no_fallback_cycle = match str_tpl.get_constraint() { + Some(constraint) => collect_function_tpl_deps_from_fallback_type( + constraint, + generic_tpl_ids, + visiting_fallbacks, + ), + None => true, + }; + visiting_fallbacks.remove(&tpl_id); + no_fallback_cycle + } + _ => true, + } +} + pub fn get_tpl_ref_extend_type( db: &DbIndex, cache: &mut LuaInferCache, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/test.rs b/crates/emmylua_code_analysis/src/semantic/generic/test.rs index be0188d9a..ebc0fc52d 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/test.rs @@ -1,5 +1,6 @@ #[cfg(test)] mod test { + use super::super::instantiate_type::{regularize_tpl_candidate_type, widen_tpl_candidate_type}; use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; #[test] @@ -331,4 +332,56 @@ result = { let c_ty = ws.expr_ty("C"); assert_eq!(ws.humanize_type(c_ty), "string"); } + + #[test] + fn test_regularize_tpl_candidate_type_preserves_root_primitive_and_widens_nested_literals() { + let mut ws = VirtualWorkspace::new(); + + let root_literal = ws.ty(r#""mode""#); + let regularized_root = { + let db = ws.analysis.compilation.get_db(); + regularize_tpl_candidate_type(db, root_literal.clone()) + }; + assert_eq!(regularized_root, root_literal); + + let table = ws.expr_ty(r#"{ kind = "mode", count = 1 }"#); + let regularized_table = { + let db = ws.analysis.compilation.get_db(); + regularize_tpl_candidate_type(db, table) + }; + assert_eq!(regularized_table, ws.ty("{ kind: string, count: integer }")); + } + + #[test] + fn test_widen_tpl_candidate_type_widens_root_primitive_and_structural_literals() { + let mut ws = VirtualWorkspace::new(); + + let root_literal = ws.ty(r#""mode""#); + let widened_root = { + let db = ws.analysis.compilation.get_db(); + widen_tpl_candidate_type(db, root_literal) + }; + assert_eq!(widened_root, LuaType::String); + + let root_union = ws.ty(r#""left" | "right""#); + let widened_root_union = { + let db = ws.analysis.compilation.get_db(); + widen_tpl_candidate_type(db, root_union.clone()) + }; + assert_eq!(widened_root_union, root_union); + + let tuple = ws.expr_ty(r#"{ "mode", 1 }"#); + let widened_tuple = { + let db = ws.analysis.compilation.get_db(); + widen_tpl_candidate_type(db, tuple) + }; + assert_eq!(ws.humanize_type(widened_tuple), "(string,integer)"); + + let table = ws.expr_ty(r#"{ kind = "mode", count = 1 }"#); + let widened_table = { + let db = ws.analysis.compilation.get_db(); + widen_tpl_candidate_type(db, table) + }; + assert_eq!(widened_table, ws.ty("{ kind: string, count: integer }")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs deleted file mode 100644 index bab917412..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs +++ /dev/null @@ -1,77 +0,0 @@ -use emmylua_parser::LuaCallExpr; - -use super::instantiate_type::TplCandidateSource; -use crate::{ - DbIndex, GenericTplId, LuaInferCache, LuaType, TypeSubstitutor, - semantic::generic::type_substitutor::TplBinding, -}; - -#[derive(Debug)] -pub struct TplContext<'a> { - pub db: &'a DbIndex, - pub cache: &'a mut LuaInferCache, - pub substitutor: &'a mut TypeSubstitutor, - pub call_expr: Option, - inference_top_level: bool, -} - -impl<'a> TplContext<'a> { - pub fn new( - db: &'a DbIndex, - cache: &'a mut LuaInferCache, - substitutor: &'a mut TypeSubstitutor, - call_expr: Option, - ) -> Self { - Self { - db, - cache, - substitutor, - call_expr, - inference_top_level: true, - } - } - - pub fn with_inference_top_level( - &mut self, - top_level: bool, - f: impl FnOnce(&mut Self) -> R, - ) -> R { - let previous = self.inference_top_level; - self.inference_top_level = previous && top_level; - let result = f(self); - self.inference_top_level = previous; - result - } - - pub(in crate::semantic::generic) fn insert_type( - &mut self, - tpl_id: GenericTplId, - replace_type: LuaType, - source: TplCandidateSource, - ) { - self.substitutor.bind( - tpl_id, - TplBinding::InferredType { - ty: replace_type, - source, - top_level: self.inference_top_level, - }, - ); - } - - pub(in crate::semantic::generic) fn insert_multi_types( - &mut self, - tpl_id: GenericTplId, - types: Vec, - source: TplCandidateSource, - ) { - self.substitutor.bind( - tpl_id, - TplBinding::InferredMultiTypes { - types, - source, - top_level: self.inference_top_level, - }, - ); - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs deleted file mode 100644 index 75d498ca4..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ /dev/null @@ -1,138 +0,0 @@ -use crate::{ - InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaType, LuaTypeNode, TplContext, - TypeSubstitutor, instantiate_type_generic, - semantic::generic::tpl_pattern::{ - TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, - }, -}; - -pub fn generic_tpl_pattern_match( - context: &mut TplContext, - generic: &LuaGenericType, - target: &LuaType, -) -> TplPatternMatchResult { - generic_tpl_pattern_match_inner(context, generic, target, &InferGuard::new()) -} - -fn generic_tpl_pattern_match_inner( - context: &mut TplContext, - source_generic: &LuaGenericType, - target: &LuaType, - infer_guard: &InferGuardRef, -) -> TplPatternMatchResult { - match target { - LuaType::Generic(target_generic) => { - let base = source_generic.get_base_type_id_ref(); - let target_base = target_generic.get_base_type_id_ref(); - if base == target_base { - let params = source_generic.get_params(); - let target_params = target_generic.get_params(); - let min_len = params.len().min(target_params.len()); - for i in 0..min_len { - match (¶ms[i], &target_params[i]) { - (LuaType::Variadic(variadict), _) => { - variadic_tpl_pattern_match(context, variadict, &target_params[i..])?; - break; - } - _ => { - tpl_pattern_match(context, ¶ms[i], &target_params[i])?; - } - } - } - return Ok(()); - } - - let target_decl = context - .db - .get_type_index() - .get_type_decl(target_base) - .ok_or(InferFailReason::None)?; - if target_decl.is_alias() { - let substitutor = TypeSubstitutor::from_alias( - context.db, - target_generic.get_params().clone(), - target_base.clone(), - ); - if let Some(origin_type) = - target_decl.get_alias_origin(context.db, Some(&substitutor)) - { - return generic_tpl_pattern_match_inner( - context, - source_generic, - &origin_type, - infer_guard, - ); - } - } else if let Some(super_types) = - context.db.get_type_index().get_super_types(target_base) - { - for mut super_type in super_types { - if super_type.contains_tpl_node() { - let substitutor = - TypeSubstitutor::from_type_array(target_generic.get_params().clone()); - super_type = - instantiate_type_generic(context.db, &super_type, &substitutor); - } - - generic_tpl_pattern_match_inner( - context, - source_generic, - &super_type, - &infer_guard.fork(), - )?; - } - } - } - LuaType::Ref(type_id) | LuaType::Def(type_id) => { - infer_guard.check(type_id)?; - let type_decl = context - .db - .get_type_index() - .get_type_decl(type_id) - .ok_or(InferFailReason::None)?; - if let Some(origin_type) = type_decl.get_alias_origin(context.db, None) { - return generic_tpl_pattern_match_inner( - context, - source_generic, - &origin_type, - infer_guard, - ); - } - - for super_type in context - .db - .get_type_index() - .get_super_types(type_id) - .unwrap_or_default() - { - generic_tpl_pattern_match_inner( - context, - source_generic, - &super_type, - &infer_guard.fork(), - )?; - } - } - LuaType::Union(union_type) => { - for union_sub_type in &union_type.into_vec() { - generic_tpl_pattern_match_inner( - context, - source_generic, - union_sub_type, - &infer_guard.fork(), - )?; - } - } - _ => { - // 对于 @alias 类型, 我们能拿到的 target 实际上很有可能是实例化后的类型, 因此我们需要实例化后再进行匹配 - let substitutor = TypeSubstitutor::new(); - let generic_ty = LuaType::Generic(source_generic.clone().into()); - let typ = instantiate_type_generic(context.db, &generic_ty, &substitutor); - if LuaType::from(source_generic.clone()) != typ { - tpl_pattern_match(context, &typ, target)?; - } - } - } - - Ok(()) -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs deleted file mode 100644 index 0f8751ea3..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::{ - InferFailReason, LuaSignatureId, LuaType, TplContext, infer_expr, - semantic::generic::tpl_pattern::TplPatternMatchResult, -}; - -pub fn check_lambda_tpl_pattern( - context: &mut TplContext, - signature_id: LuaSignatureId, -) -> TplPatternMatchResult { - let call_expr = context.call_expr.clone().ok_or(InferFailReason::None)?; - let call_arg_list = call_expr.get_args_list().ok_or(InferFailReason::None)?; - for arg in call_arg_list.get_args() { - if let Ok(LuaType::Signature(arg_signature_id)) = - infer_expr(context.db, context.cache, arg.clone()) - && arg_signature_id == signature_id - { - return Ok(()); - } - } - - Err(InferFailReason::UnResolveSignatureReturn(signature_id)) -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs deleted file mode 100644 index 9bf830086..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ /dev/null @@ -1,1022 +0,0 @@ -mod generic_tpl_pattern; -mod lambda_tpl_pattern; - -use std::{ops::Deref, sync::Arc}; - -use emmylua_parser::LuaAstNode; -use itertools::Itertools; -use rowan::NodeOrToken; -use smol_str::SmolStr; - -use crate::{ - InferFailReason, LuaFunctionType, LuaMemberInfo, LuaMemberKey, LuaMemberOwner, LuaObjectType, - LuaSemanticDeclId, LuaTupleType, LuaTypeDeclId, LuaTypeNode, LuaUnionType, SemanticDeclLevel, - VariadicType, check_type_compact, - db_index::{DbIndex, LuaGenericType, LuaType}, - infer_node_semantic_decl, - semantic::{ - generic::{ - tpl_context::TplContext, tpl_pattern::generic_tpl_pattern::generic_tpl_pattern_match, - type_substitutor::SubstitutorValue, - }, - member::{find_index_operations, get_member_map}, - }, -}; - -use super::{ - instantiate_type::TplCandidateSource::{ConstPreserving, Plain}, - type_substitutor::{TplBinding, TypeSubstitutor}, -}; -use std::collections::HashMap; - -type TplPatternMatchResult = Result<(), InferFailReason>; - -pub fn tpl_pattern_match_args( - context: &mut TplContext, - func_param_types: &[LuaType], - call_arg_types: &[LuaType], -) -> TplPatternMatchResult { - tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, false) -} - -pub fn tpl_pattern_match_args_skip_unknown( - context: &mut TplContext, - func_param_types: &[LuaType], - call_arg_types: &[LuaType], -) -> TplPatternMatchResult { - tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, true) -} - -fn tpl_pattern_match_args_inner( - context: &mut TplContext, - func_param_types: &[LuaType], - call_arg_types: &[LuaType], - skip_unknown_tpl: bool, -) -> TplPatternMatchResult { - for i in 0..func_param_types.len() { - if i >= call_arg_types.len() { - break; - } - - let func_param_type = &func_param_types[i]; - let call_arg_type = &call_arg_types[i]; - - match (func_param_type, call_arg_type) { - (LuaType::Variadic(variadic), _) => { - variadic_tpl_pattern_match(context, variadic, &call_arg_types[i..])?; - break; - } - (_, LuaType::Variadic(variadic)) => { - multi_param_tpl_pattern_match_multi_return( - context, - &func_param_types[i..], - variadic, - )?; - break; - } - _ if skip_unknown_tpl - && func_param_type.contain_tpl() - && (call_arg_type.is_any() || call_arg_type.is_unknown()) => {} - _ => { - tpl_pattern_match(context, func_param_type, call_arg_type)?; - } - } - } - - Ok(()) -} - -pub fn multi_param_tpl_pattern_match_multi_return( - context: &mut TplContext, - func_param_types: &[LuaType], - multi_return: &VariadicType, -) -> TplPatternMatchResult { - match &multi_return { - VariadicType::Base(base) => { - let mut call_arg_types = Vec::new(); - for param in func_param_types { - if param.is_variadic() { - call_arg_types.push(LuaType::Variadic(multi_return.clone().into())); - break; - } else { - call_arg_types.push(base.clone()); - } - } - - tpl_pattern_match_args(context, func_param_types, &call_arg_types)?; - } - VariadicType::Multi(_) => { - let mut call_arg_types = Vec::new(); - for (i, param) in func_param_types.iter().enumerate() { - let Some(return_type) = multi_return.get_type(i) else { - break; - }; - - if param.is_variadic() { - call_arg_types.push(LuaType::Variadic( - multi_return.get_new_variadic_from(i).into(), - )); - break; - } else { - call_arg_types.push(return_type.clone()); - } - } - - tpl_pattern_match_args(context, func_param_types, &call_arg_types)?; - } - } - - Ok(()) -} - -fn get_str_tpl_infer_type(name: &str) -> LuaType { - match name { - "unknown" => LuaType::Unknown, - "never" => LuaType::Never, - "nil" | "void" => LuaType::Nil, - "any" => LuaType::Any, - "userdata" => LuaType::Userdata, - "thread" => LuaType::Thread, - "boolean" | "bool" => LuaType::Boolean, - "string" => LuaType::String, - "integer" | "int" => LuaType::Integer, - "number" => LuaType::Number, - "io" => LuaType::Io, - "self" => LuaType::SelfInfer, - "global" => LuaType::Global, - "function" => LuaType::Function, - _ => LuaType::Ref(LuaTypeDeclId::global(&name)), - } -} - -pub fn tpl_pattern_match( - context: &mut TplContext, - pattern: &LuaType, - target: &LuaType, -) -> TplPatternMatchResult { - let target = escape_alias(context.db, target); - if !pattern.contains_tpl_node() { - return Ok(()); - } - - match pattern { - LuaType::TplRef(tpl) => { - if tpl.get_tpl_id().is_func() { - context.insert_type(tpl.get_tpl_id(), target.clone(), Plain); - } - } - LuaType::ConstTplRef(tpl) => { - if tpl.get_tpl_id().is_func() { - context.insert_type(tpl.get_tpl_id(), target, ConstPreserving); - } - } - LuaType::StrTplRef(str_tpl) => { - if let LuaType::StringConst(s) = target { - let prefix = str_tpl.get_prefix(); - let suffix = str_tpl.get_suffix(); - let type_name = SmolStr::new(format!("{}{}{}", prefix, s, suffix)); - context.insert_type( - str_tpl.get_tpl_id(), - get_str_tpl_infer_type(&type_name), - Plain, - ); - } - } - LuaType::Array(array_type) => { - context.with_inference_top_level(false, |context| { - array_tpl_pattern_match(context, array_type.get_base(), &target) - })?; - } - LuaType::TableGeneric(table_generic_params) => { - context.with_inference_top_level(false, |context| { - table_generic_tpl_pattern_match(context, table_generic_params, &target) - })?; - } - LuaType::Generic(generic) => { - context.with_inference_top_level(false, |context| { - generic_tpl_pattern_match(context, generic, &target) - })?; - } - LuaType::Union(union) => { - union_tpl_pattern_match(context, union, &target)?; - } - LuaType::DocFunction(doc_func) => { - context.with_inference_top_level(false, |context| { - func_tpl_pattern_match(context, doc_func, &target) - })?; - } - LuaType::Tuple(tuple) => { - context.with_inference_top_level(false, |context| { - tuple_tpl_pattern_match(context, tuple, &target) - })?; - } - LuaType::Object(obj) => { - context.with_inference_top_level(false, |context| { - object_tpl_pattern_match(context, obj, &target) - })?; - } - _ => {} - } - - Ok(()) -} - -fn object_tpl_pattern_match( - context: &mut TplContext, - origin_obj: &LuaObjectType, - target: &LuaType, -) -> TplPatternMatchResult { - match target { - LuaType::Object(target_object) => { - // 先匹配 fields - for (k, v) in origin_obj.get_fields().iter().sorted_by_key(|(k, _)| *k) { - let target_value = target_object.get_fields().get(k); - if let Some(target_value) = target_value { - tpl_pattern_match(context, v, target_value)?; - } - } - // 再匹配索引访问 - let target_index_access = target_object.get_index_access(); - for (origin_key, v) in origin_obj.get_index_access() { - // 先匹配 key 类型进行转换 - let target_access = target_index_access.iter().find(|(target_key, _)| { - check_type_compact(context.db, origin_key, target_key).is_ok() - }); - if let Some(target_access) = target_access { - tpl_pattern_match(context, origin_key, &target_access.0)?; - tpl_pattern_match(context, v, &target_access.1)?; - } - } - } - LuaType::TableConst(inst) => { - let owner = LuaMemberOwner::Element(inst.clone()); - object_tpl_pattern_match_member_owner_match(context, origin_obj, owner)?; - } - _ => {} - } - - Ok(()) -} - -fn object_tpl_pattern_match_member_owner_match( - context: &mut TplContext, - object: &LuaObjectType, - owner: LuaMemberOwner, -) -> TplPatternMatchResult { - let owner_type = match &owner { - LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()), - LuaMemberOwner::Type(type_id) => LuaType::Ref(type_id.clone()), - _ => { - return Err(InferFailReason::None); - } - }; - - let members = get_member_map(context.db, &owner_type).ok_or(InferFailReason::None)?; - for (k, v) in members { - let resolve_key = match &k { - LuaMemberKey::Integer(i) => Some(LuaType::IntegerConst(*i)), - LuaMemberKey::Name(s) => Some(LuaType::StringConst(s.clone().into())), - _ => None, - }; - let resolve_type = match v.len() { - 0 => LuaType::Any, - 1 => v[0].typ.clone(), - _ => { - let mut types = Vec::new(); - for m in &v { - types.push(m.typ.clone()); - } - LuaType::from_vec(types) - } - }; - - // this is a workaround, I need refactor infer member map - if resolve_type.is_unknown() - && !v.is_empty() - && let Some(LuaSemanticDeclId::Member(member_id)) = &v[0].property_owner_id - { - return Err(InferFailReason::UnResolveMemberType(*member_id)); - } - - if let Some(_) = resolve_key - && let Some(field_value) = object.get_field(&k) - { - tpl_pattern_match(context, field_value, &resolve_type)?; - } - } - - Ok(()) -} - -fn array_tpl_pattern_match( - context: &mut TplContext, - base: &LuaType, - target: &LuaType, -) -> TplPatternMatchResult { - match target { - LuaType::Array(target_array_type) => { - tpl_pattern_match(context, base, target_array_type.get_base())?; - } - LuaType::Tuple(target_tuple) => { - let target_base = target_tuple.cast_down_array_base(context.db); - tpl_pattern_match(context, base, &target_base)?; - } - LuaType::Object(target_object) => { - let target_base = target_object - .cast_down_array_base(context.db) - .ok_or(InferFailReason::None)?; - tpl_pattern_match(context, base, &target_base)?; - } - _ => {} - } - - Ok(()) -} - -fn table_generic_tpl_pattern_match( - context: &mut TplContext, - table_generic_params: &[LuaType], - target: &LuaType, -) -> TplPatternMatchResult { - if table_generic_params.len() != 2 { - return Err(InferFailReason::None); - } - - match target { - LuaType::TableGeneric(target_table_generic_params) => { - let min_len = table_generic_params - .len() - .min(target_table_generic_params.len()); - for i in 0..min_len { - tpl_pattern_match( - context, - &table_generic_params[i], - &target_table_generic_params[i], - )?; - } - } - LuaType::Array(target_array_base) => { - tpl_pattern_match(context, &table_generic_params[0], &LuaType::Integer)?; - tpl_pattern_match( - context, - &table_generic_params[1], - target_array_base.get_base(), - )?; - } - LuaType::Tuple(target_tuple) => { - let len = target_tuple.get_types().len(); - let mut keys = Vec::new(); - for i in 0..len { - keys.push(LuaType::IntegerConst((i as i64) + 1)); - } - - let key_type = LuaType::Union(LuaUnionType::from_vec(keys).into()); - let target_base = target_tuple.cast_down_array_base(context.db); - tpl_pattern_match(context, &table_generic_params[0], &key_type)?; - tpl_pattern_match(context, &table_generic_params[1], &target_base)?; - } - LuaType::TableConst(inst) => { - let owner = LuaMemberOwner::Element(inst.clone()); - table_generic_tpl_pattern_member_owner_match( - context, - table_generic_params, - owner, - &[], - )?; - } - LuaType::Ref(type_id) => { - let owner = LuaMemberOwner::Type(type_id.clone()); - table_generic_tpl_pattern_member_owner_match( - context, - table_generic_params, - owner, - &[], - )?; - } - LuaType::Def(type_id) => { - let owner = LuaMemberOwner::Type(type_id.clone()); - table_generic_tpl_pattern_member_owner_match( - context, - table_generic_params, - owner, - &[], - )?; - } - LuaType::Generic(generic) => { - let owner = LuaMemberOwner::Type(generic.get_base_type_id()); - let target_params = generic.get_params(); - table_generic_tpl_pattern_member_owner_match( - context, - table_generic_params, - owner, - target_params, - )?; - } - LuaType::Object(obj) => { - let mut keys = Vec::new(); - let mut values = Vec::new(); - for (k, v) in obj.get_fields() { - match k { - LuaMemberKey::Integer(i) => { - keys.push(LuaType::IntegerConst(*i)); - } - LuaMemberKey::Name(s) => { - keys.push(LuaType::StringConst(s.clone().into())); - } - _ => {} - }; - values.push(v.clone()); - } - for (k, v) in obj.get_index_access() { - keys.push(k.clone()); - values.push(v.clone()); - } - - let key_type = LuaType::Union(LuaUnionType::from_vec(keys).into()); - let value_type = LuaType::Union(LuaUnionType::from_vec(values).into()); - tpl_pattern_match(context, &table_generic_params[0], &key_type)?; - tpl_pattern_match(context, &table_generic_params[1], &value_type)?; - } - - LuaType::Global | LuaType::Any | LuaType::Table | LuaType::Userdata => { - // too many - tpl_pattern_match(context, &table_generic_params[0], &LuaType::Any)?; - tpl_pattern_match(context, &table_generic_params[1], &LuaType::Any)?; - } - _ => {} - } - - Ok(()) -} - -// KV 表匹配 ref/def/tableconst -fn table_generic_tpl_pattern_member_owner_match( - context: &mut TplContext, - table_generic_params: &[LuaType], - owner: LuaMemberOwner, - target_params: &[LuaType], -) -> TplPatternMatchResult { - if table_generic_params.len() != 2 { - return Err(InferFailReason::None); - } - - let owner_type = match &owner { - LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()), - LuaMemberOwner::Type(type_id) => match target_params.len() { - 0 => LuaType::Ref(type_id.clone()), - _ => LuaType::Generic(Arc::new(LuaGenericType::new( - type_id.clone(), - target_params.to_vec(), - ))), - }, - _ => { - return Err(InferFailReason::None); - } - }; - - let members = get_member_map(context.db, &owner_type).ok_or(InferFailReason::None)?; - // 如果是 pairs 调用, 我们需要尝试寻找元方法, 但目前`__pairs` 被放进成员表中 - if is_pairs_call(context).unwrap_or(false) - && try_handle_pairs_metamethod(context, table_generic_params, &members).is_ok() - { - return Ok(()); - } - - let target_key_type = table_generic_params[0].clone(); - let mut keys = Vec::new(); - let mut values = Vec::new(); - for (k, v) in members { - let key_type = match k { - LuaMemberKey::Integer(i) => LuaType::IntegerConst(i), - LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), - LuaMemberKey::ExprType(typ) => typ, - _ => continue, - }; - - if !target_key_type.is_generic() - && check_type_compact(context.db, &target_key_type, &key_type).is_err() - { - continue; - } - - keys.push(key_type); - - let resolve_type = match v.len() { - 0 => LuaType::Any, - 1 => v[0].typ.clone(), - _ => { - let mut types = Vec::new(); - for m in v { - types.push(m.typ.clone()); - } - LuaType::from_vec(types) - } - }; - - values.push(resolve_type); - } - - if keys.is_empty() { - find_index_operations(context.db, &owner_type) - .ok_or(InferFailReason::None)? - .iter() - .for_each(|m| { - if target_key_type.is_generic() { - return; - } - let key_type = match &m.key { - LuaMemberKey::ExprType(typ) => typ.clone(), - _ => return, - }; - if check_type_compact(context.db, &target_key_type, &key_type).is_ok() { - keys.push(key_type); - values.push(m.typ.clone()); - } - }); - } - - let key_type = match &keys[..] { - [] => return Err(InferFailReason::None), - [first] => first.clone(), - _ => LuaType::Union(LuaUnionType::from_vec(keys).into()), - }; - let value_type = match &values[..] { - [first] => first.clone(), - _ => LuaType::Union(LuaUnionType::from_vec(values).into()), - }; - - tpl_pattern_match(context, &table_generic_params[0], &key_type)?; - tpl_pattern_match(context, &table_generic_params[1], &value_type)?; - - Ok(()) -} - -fn union_tpl_pattern_match( - context: &mut TplContext, - union: &LuaUnionType, - target: &LuaType, -) -> TplPatternMatchResult { - let mut error_count = 0; - let mut last_error = InferFailReason::None; - for u in union.into_vec() { - match tpl_pattern_match(context, &u, target) { - // 返回 ok 时并不一定匹配成功, 仅表示没有发生错误 - Ok(_) => {} - Err(e) => { - error_count += 1; - last_error = e; - } - } - } - - if error_count == union.into_vec().len() { - Err(last_error) - } else { - Ok(()) - } -} - -fn func_tpl_pattern_match( - context: &mut TplContext, - tpl_func: &LuaFunctionType, - target: &LuaType, -) -> TplPatternMatchResult { - match target { - LuaType::DocFunction(target_doc_func) => { - func_tpl_pattern_match_doc_func(context, tpl_func, target_doc_func)?; - } - LuaType::Signature(signature_id) => { - let signature = context - .db - .get_signature_index() - .get(signature_id) - .ok_or(InferFailReason::None)?; - if !signature.is_resolve_return() { - return lambda_tpl_pattern::check_lambda_tpl_pattern(context, *signature_id); - } - let fake_doc_func = signature.to_doc_func_type(); - func_tpl_pattern_match_doc_func(context, tpl_func, &fake_doc_func)?; - } - _ => {} - } - - Ok(()) -} - -fn func_tpl_pattern_match_doc_func( - context: &mut TplContext, - tpl_func: &LuaFunctionType, - target_func: &LuaFunctionType, -) -> TplPatternMatchResult { - let mut tpl_func_params = tpl_func.get_params().to_vec(); - if tpl_func.is_colon_define() { - tpl_func_params.insert(0, ("self".to_string(), Some(LuaType::Any))); - } - - let mut target_func_params = target_func.get_params().to_vec(); - - if target_func.is_colon_define() { - target_func_params.insert(0, ("self".to_string(), Some(LuaType::Any))); - } - - param_type_list_pattern_match_type_list(context, &tpl_func_params, &target_func_params)?; - - let tpl_return = tpl_func.get_ret(); - let target_return = target_func.get_ret(); - return_type_pattern_match_target_type(context, tpl_return, target_return)?; - - Ok(()) -} - -fn param_type_list_pattern_match_type_list( - context: &mut TplContext, - sources: &[(String, Option)], - targets: &[(String, Option)], -) -> TplPatternMatchResult { - let type_len = sources.len(); - let mut target_offset = 0; - for i in 0..type_len { - let source = match sources.get(i) { - Some(t) => t.1.clone().unwrap_or(LuaType::Any), - None => break, - }; - - match &source { - LuaType::Variadic(inner) => { - let i = i + target_offset; - if i >= targets.len() { - if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() { - let tpl_id = tpl_ref.get_tpl_id(); - context.insert_type(tpl_id, LuaType::Nil, Plain); - } - break; - } - - if let VariadicType::Base(LuaType::TplRef(generic_tpl)) = inner.deref() { - let tpl_id = generic_tpl.get_tpl_id(); - if let Some(inferred_type_value) = context.substitutor.get(tpl_id) { - match inferred_type_value { - SubstitutorValue::Type { .. } => { - continue; - } - SubstitutorValue::MultiTypes { values, .. } => { - if values.len() > 1 { - target_offset += values.len() - 1; - } - continue; - } - SubstitutorValue::Params(params) => { - if params.len() > 1 { - target_offset += params.len() - 1; - } - continue; - } - _ => {} - } - } - } - - let mut target_rest_params = &targets[i..]; - // If the variadic parameter is not the last one, then target_rest_params should exclude the parameters that come after it. - if i + 1 < type_len { - let source_rest_len = type_len - i - 1; - if source_rest_len >= target_rest_params.len() { - continue; - } - let target_rest_len = target_rest_params.len() - source_rest_len; - target_rest_params = &target_rest_params[..target_rest_len]; - if target_rest_len > 1 { - target_offset += target_rest_len - 1; - } - } - - func_varargs_tpl_pattern_match(inner, target_rest_params, context.substitutor)?; - } - _ => { - let target = match targets.get(i + target_offset) { - Some(t) => t.1.clone().unwrap_or(LuaType::Any), - None => break, - }; - tpl_pattern_match(context, &source, &target)?; - } - } - } - - Ok(()) -} - -pub(crate) fn return_type_pattern_match_target_type( - context: &mut TplContext, - source: &LuaType, - target: &LuaType, -) -> TplPatternMatchResult { - match (source, target) { - // toooooo complex - (LuaType::Variadic(variadic_source), LuaType::Variadic(variadic_target)) => { - match variadic_target.deref() { - VariadicType::Base(target_base) => match variadic_source.deref() { - VariadicType::Base(source_base) => { - if let LuaType::TplRef(type_ref) = source_base { - let tpl_id = type_ref.get_tpl_id(); - context.insert_type(tpl_id, target_base.clone(), Plain); - } - } - VariadicType::Multi(source_multi) => { - for ret_type in source_multi { - match ret_type { - LuaType::Variadic(inner) => { - if let VariadicType::Base(base) = inner.deref() - && let LuaType::TplRef(type_ref) = base - { - let tpl_id = type_ref.get_tpl_id(); - context.insert_type(tpl_id, target_base.clone(), Plain); - } - - break; - } - LuaType::TplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - context.insert_type(tpl_id, target_base.clone(), Plain); - } - _ => {} - } - } - } - }, - VariadicType::Multi(target_types) => { - variadic_tpl_pattern_match(context, variadic_source, target_types)?; - } - } - } - (LuaType::Variadic(variadic), _) => { - variadic_tpl_pattern_match(context, variadic, std::slice::from_ref(target))?; - } - (_, LuaType::Variadic(variadic)) => { - multi_param_tpl_pattern_match_multi_return( - context, - std::slice::from_ref(source), - variadic, - )?; - } - _ => { - tpl_pattern_match(context, source, target)?; - } - } - - Ok(()) -} - -fn func_varargs_tpl_pattern_match( - variadic: &VariadicType, - target_rest_params: &[(String, Option)], - substitutor: &mut TypeSubstitutor, -) -> TplPatternMatchResult { - match variadic { - VariadicType::Base(base) => { - if let LuaType::TplRef(tpl_ref) = base { - let tpl_id = tpl_ref.get_tpl_id(); - substitutor.bind( - tpl_id, - TplBinding::VariadicParams( - target_rest_params - .iter() - .map(|(n, t)| (n.clone(), t.clone())) - .collect(), - ), - ); - } - } - VariadicType::Multi(_) => {} - } - - Ok(()) -} - -pub fn variadic_tpl_pattern_match( - context: &mut TplContext, - tpl: &VariadicType, - target_rest_types: &[LuaType], -) -> TplPatternMatchResult { - match tpl { - VariadicType::Base(base) => match base { - LuaType::TplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.len() { - 0 => { - context.insert_type(tpl_id, LuaType::Nil, Plain); - } - 1 => { - // If the single argument is itself a multi-return (e.g. a function call - // returning multiple values), expand it so that `T...` receives all the - // return values rather than a single Variadic wrapper. - match &target_rest_types[0] { - LuaType::Variadic(variadic) => match variadic.deref() { - VariadicType::Multi(types) => match types.len() { - 0 => { - context.insert_type(tpl_id, LuaType::Nil, Plain); - } - 1 => { - context.insert_type(tpl_id, types[0].clone(), Plain); - } - _ => { - context.insert_multi_types(tpl_id, types.to_vec(), Plain); - } - }, - VariadicType::Base(base) => { - context - .substitutor - .bind(tpl_id, TplBinding::VariadicBase(base.clone())); - } - }, - arg => { - context.insert_type(tpl_id, arg.clone(), Plain); - } - } - } - _ => { - context.insert_multi_types(tpl_id, target_rest_types.to_vec(), Plain); - } - } - } - LuaType::ConstTplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.len() { - 0 => { - context.insert_type(tpl_id, LuaType::Nil, ConstPreserving); - } - 1 => { - context.insert_type(tpl_id, target_rest_types[0].clone(), ConstPreserving); - } - _ => { - context.insert_multi_types( - tpl_id, - target_rest_types.to_vec(), - ConstPreserving, - ); - } - } - } - _ => {} - }, - VariadicType::Multi(multi) => { - for (i, ret_type) in multi.iter().enumerate() { - match ret_type { - LuaType::Variadic(inner) => { - if i < target_rest_types.len() { - variadic_tpl_pattern_match(context, inner, &target_rest_types[i..])?; - } - - break; - } - LuaType::TplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.get(i) { - Some(t) => { - context.insert_type(tpl_id, t.clone(), Plain); - } - None => { - break; - } - }; - } - _ => {} - } - } - } - } - - Ok(()) -} - -fn tuple_tpl_pattern_match( - context: &mut TplContext, - tpl_tuple: &LuaTupleType, - target: &LuaType, -) -> TplPatternMatchResult { - match target { - LuaType::Tuple(target_tuple) => { - let tpl_tuple_types = tpl_tuple.get_types(); - let target_tuple_types = target_tuple.get_types(); - let tpl_tuple_len = tpl_tuple_types.len(); - for i in 0..tpl_tuple_len { - let tpl_type = &tpl_tuple_types[i]; - - if let LuaType::Variadic(inner) = tpl_type { - let target_rest_types = &target_tuple_types[i..]; - variadic_tpl_pattern_match(context, inner, target_rest_types)?; - break; - } - - let target_type = match target_tuple_types.get(i) { - Some(t) => t, - None => break, - }; - - tpl_pattern_match(context, tpl_type, target_type)?; - } - } - LuaType::Array(target_array_base) => { - let tupl_tuple_types = tpl_tuple.get_types(); - let last_type = tupl_tuple_types.last().ok_or(InferFailReason::None)?; - if let LuaType::Variadic(inner) = last_type { - match inner.deref() { - VariadicType::Base(base) => { - if let LuaType::TplRef(tpl_ref) = base { - let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.bind( - tpl_id, - TplBinding::VariadicBase(target_array_base.get_base().clone()), - ); - } - } - VariadicType::Multi(_) => {} - } - } - } - _ => {} - } - - Ok(()) -} - -fn escape_alias(db: &DbIndex, may_alias: &LuaType) -> LuaType { - if let LuaType::Ref(type_id) = may_alias - && let Some(type_decl) = db.get_type_index().get_type_decl(type_id) - && type_decl.is_alias() - && let Some(origin_type) = type_decl.get_alias_origin(db, None) - { - return origin_type.clone(); - } - - may_alias.clone() -} - -fn is_pairs_call(context: &mut TplContext) -> Option { - let call_expr = context.call_expr.as_ref()?; - let prefix_expr = call_expr.get_prefix_expr()?; - let semantic_decl = match prefix_expr.syntax().clone().into() { - NodeOrToken::Node(node) => infer_node_semantic_decl( - context.db, - context.cache, - node, - SemanticDeclLevel::default(), - ), - _ => None, - }?; - - let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl else { - return None; - }; - let decl = context.db.get_decl_index().get_decl(&decl_id)?; - if !context.db.get_module_index().is_std(&decl.get_file_id()) { - return None; - } - let name = decl.get_name(); - if name != "pairs" { - return None; - } - Some(true) -} - -fn try_handle_pairs_metamethod( - context: &mut TplContext, - table_generic_params: &[LuaType], - members: &HashMap>, -) -> TplPatternMatchResult { - let pairs_member = members - .get(&LuaMemberKey::Name("__pairs".into())) - .ok_or(InferFailReason::None)? - .first() - .ok_or(InferFailReason::None)?; - // 获取迭代函数返回类型 - let meta_return = match &pairs_member.typ { - LuaType::Signature(signature_id) => context - .db - .get_signature_index() - .get(signature_id) - .map(|s| s.get_return_type()), - LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), - _ => None, - } - .ok_or(InferFailReason::None)?; - - // 解析出迭代函数返回类型 - let final_return_type = match meta_return { - LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), - LuaType::Signature(signature_id) => context - .db - .get_signature_index() - .get(&signature_id) - .map(|s| s.get_return_type()), - _ => None, - }; - - if let Some(LuaType::Variadic(variadic)) = &final_return_type { - let key_type = variadic.get_type(0).ok_or(InferFailReason::None)?; - let value_type = variadic.get_type(1).ok_or(InferFailReason::None)?; - tpl_pattern_match(context, &table_generic_params[0], key_type)?; - tpl_pattern_match(context, &table_generic_params[1], value_type)?; - return Ok(()); - } - Err(InferFailReason::None) -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index 6dbf090da..b212f8fc7 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -1,7 +1,6 @@ use hashbrown::{HashMap, HashSet}; -use super::instantiate_type::{TplCandidateSource, finalize_inferred_tpl_candidate}; -use crate::{DbIndex, GenericTpl, GenericTplId, LuaType, LuaTypeDeclId}; +use crate::{DbIndex, GenericTplId, LuaType, LuaTypeDeclId}; use std::sync::Arc; const MAX_INSTANTIATION_DEPTH: usize = 128; @@ -17,19 +16,10 @@ pub(super) enum UninferredTplPolicy { pub(in crate::semantic::generic) enum TplBinding { FinalizedType(LuaType), - InferredType { - ty: LuaType, - source: TplCandidateSource, - top_level: bool, - }, ReplaceConstType(LuaType), ConditionalInferType(LuaType), VariadicParams(Vec<(String, Option)>), - InferredMultiTypes { - types: Vec, - source: TplCandidateSource, - top_level: bool, - }, + InferredMultiTypes(Vec), VariadicBase(LuaType), } @@ -143,9 +133,7 @@ impl TypeSubstitutor { tpl_replace_map.insert( GenericTplId::Type(i as u32), SubstitutorValue::Type { - value: SubstitutorTypeValue::new(ty, true), - source: TplCandidateSource::Finalized, - top_level: true, + value: SubstitutorTypeValue::new(ty), }, ); } @@ -172,9 +160,7 @@ impl TypeSubstitutor { tpl_replace_map.insert( tpl_id, SubstitutorValue::Type { - value: SubstitutorTypeValue::new(ty, true), - source: TplCandidateSource::Finalized, - top_level: true, + value: SubstitutorTypeValue::new(ty), }, ); } @@ -220,9 +206,7 @@ impl TypeSubstitutor { self.tpl_replace_map.insert( tpl_id, SubstitutorValue::Type { - value: SubstitutorTypeValue::new(replace_type, false), - source: TplCandidateSource::ConstPreserving, - top_level: true, + value: SubstitutorTypeValue::new(replace_type), }, ); } @@ -234,9 +218,7 @@ impl TypeSubstitutor { self.tpl_replace_map.insert( tpl_id, SubstitutorValue::Type { - value: SubstitutorTypeValue::new(replace_type, false), - source: TplCandidateSource::ConstPreserving, - top_level: true, + value: SubstitutorTypeValue::new(replace_type), }, ); } @@ -248,18 +230,7 @@ impl TypeSubstitutor { let value = match binding { TplBinding::FinalizedType(replace_type) => SubstitutorValue::Type { - value: SubstitutorTypeValue::new(replace_type, true), - source: TplCandidateSource::Finalized, - top_level: true, - }, - TplBinding::InferredType { - ty, - source, - top_level, - } => SubstitutorValue::Type { - value: SubstitutorTypeValue::new(ty, false), - source, - top_level, + value: SubstitutorTypeValue::new(replace_type), }, TplBinding::VariadicParams(params) => { let params = params @@ -268,22 +239,8 @@ impl TypeSubstitutor { .collect(); SubstitutorValue::Params(params) } - TplBinding::InferredMultiTypes { - types, - source, - top_level, - } => SubstitutorValue::MultiTypes { - values: types - .into_iter() - .map(|ty| { - SubstitutorTypeValue::new( - ty, - source == TplCandidateSource::Finalized, - ) - }) - .collect(), - source, - top_level, + TplBinding::InferredMultiTypes(types) => SubstitutorValue::MultiTypes { + values: types.into_iter().map(SubstitutorTypeValue::new).collect(), }, TplBinding::VariadicBase(type_base) => SubstitutorValue::MultiBase(type_base), TplBinding::ReplaceConstType(_) | TplBinding::ConditionalInferType(_) => { @@ -315,79 +272,6 @@ impl TypeSubstitutor { } } - pub(super) fn finalize_inferred_types<'a>( - &mut self, - db: &DbIndex, - generic_tpls: impl IntoIterator>, - return_type: &LuaType, - ) { - for tpl in generic_tpls { - let tpl_id = tpl.get_tpl_id(); - let return_top_level = is_tpl_at_top_level(db, return_type, tpl_id); - let Some(value) = self.tpl_replace_map.get(&tpl_id) else { - continue; - }; - - let finalized_value = match value { - SubstitutorValue::Type { - value, - source, - top_level, - } => { - if value.is_finalized() { - None - } else { - Some(SubstitutorValue::Type { - value: value.finalized( - db, - tpl.as_ref(), - *source, - *top_level, - return_top_level, - self, - ), - source: TplCandidateSource::Finalized, - top_level: true, - }) - } - } - SubstitutorValue::MultiTypes { - values, - source, - top_level, - } => { - if *source == TplCandidateSource::Finalized { - None - } else { - let values = values - .iter() - .map(|value| { - value.finalized( - db, - tpl.as_ref(), - *source, - *top_level, - return_top_level, - self, - ) - }) - .collect(); - Some(SubstitutorValue::MultiTypes { - values, - source: TplCandidateSource::Finalized, - top_level: true, - }) - } - } - _ => None, - }; - - if let Some(finalized_value) = finalized_value { - self.tpl_replace_map.insert(tpl_id, finalized_value); - } - } - } - pub fn check_recursion(&self, type_id: &LuaTypeDeclId) -> bool { if let Some(alias_type_id) = &self.alias_type_id && alias_type_id == type_id @@ -410,14 +294,13 @@ impl TypeSubstitutor { #[derive(Debug, Clone)] pub struct SubstitutorTypeValue { raw: LuaType, - finalized: Option, } impl SubstitutorTypeValue { - fn new(raw: LuaType, already_finalized: bool) -> Self { - let raw = into_ref_type(raw); - let finalized = already_finalized.then(|| raw.clone()); - Self { raw, finalized } + fn new(raw: LuaType) -> Self { + Self { + raw: into_ref_type(raw), + } } pub fn raw(&self) -> &LuaType { @@ -425,52 +308,16 @@ impl SubstitutorTypeValue { } pub(super) fn resolved(&self) -> &LuaType { - self.finalized.as_ref().unwrap_or(&self.raw) - } - - fn is_finalized(&self) -> bool { - self.finalized.is_some() - } - - fn finalized( - &self, - db: &DbIndex, - tpl: &GenericTpl, - source: TplCandidateSource, - top_level: bool, - return_top_level: bool, - substitutor: &TypeSubstitutor, - ) -> Self { - let finalized = finalize_inferred_tpl_candidate( - db, - tpl, - &self.raw, - source, - top_level, - return_top_level, - substitutor, - ); - Self { - raw: self.raw.clone(), - finalized: Some(finalized), - } + &self.raw } } #[derive(Debug, Clone)] pub(super) enum SubstitutorValue { None, - Type { - value: SubstitutorTypeValue, - source: TplCandidateSource, - top_level: bool, - }, + Type { value: SubstitutorTypeValue }, Params(Vec<(String, Option)>), - MultiTypes { - values: Vec, - source: TplCandidateSource, - top_level: bool, - }, + MultiTypes { values: Vec }, MultiBase(LuaType), } @@ -480,79 +327,6 @@ impl SubstitutorValue { } } -fn is_tpl_at_top_level(db: &DbIndex, ty: &LuaType, tpl_id: GenericTplId) -> bool { - is_tpl_at_top_level_with_guard(db, ty, tpl_id, &mut HashSet::new()) -} - -fn is_tpl_at_top_level_with_guard( - db: &DbIndex, - ty: &LuaType, - tpl_id: GenericTplId, - visited_aliases: &mut HashSet, -) -> bool { - match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, - LuaType::Union(union) => union.into_vec().iter().any(|member| { - let mut branch_aliases = visited_aliases.clone(); - is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) - }), - LuaType::MultiLineUnion(multi) => multi.get_unions().iter().any(|(member, _)| { - let mut branch_aliases = visited_aliases.clone(); - is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) - }), - LuaType::Generic(generic) => { - let type_decl_id = generic.get_base_type_id_ref(); - let Some(alias_param) = - get_transparent_alias_param_index(db, type_decl_id, visited_aliases) - else { - return false; - }; - - generic.get_params().get(alias_param).is_some_and(|param| { - is_tpl_at_top_level_with_guard(db, param, tpl_id, visited_aliases) - }) - } - _ => false, - } -} - -fn get_transparent_alias_param_index( - db: &DbIndex, - type_decl_id: &LuaTypeDeclId, - visited_aliases: &mut HashSet, -) -> Option { - if !visited_aliases.insert(type_decl_id.clone()) { - return None; - } - - let type_decl = db.get_type_index().get_type_decl(type_decl_id)?; - if !type_decl.is_alias() { - return None; - }; - let origin = type_decl.get_alias_ref()?; - - match origin { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) - if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => - { - Some(tpl.get_tpl_id().get_idx()) - } - LuaType::Generic(generic) => { - get_transparent_alias_param_index(db, generic.get_base_type_id_ref(), visited_aliases) - .and_then(|alias_param| generic.get_params().get(alias_param)) - .and_then(|param| match param { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) - if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => - { - Some(tpl.get_tpl_id().get_idx()) - } - _ => None, - }) - } - _ => None, - } -} - fn into_ref_type(ty: LuaType) -> LuaType { match ty { LuaType::Def(type_decl_id) => LuaType::Ref(type_decl_id), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index a3cb628af..3c01665f1 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -899,7 +899,7 @@ mod tests { ); assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); - assert_eq!(ws.expr_ty("payload"), ws.ty("string")); + assert_eq!(ws.expr_ty("payload"), LuaType::Unknown); } #[test] From 58504d5edc1bd53c8d8f10e46f45cfc1a1105a79 Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Thu, 21 May 2026 20:34:28 +0800 Subject: [PATCH 07/10] refactor(generic): use TypeMapper --- .../analyzer/unresolve/find_decl_function.rs | 24 +- .../src/compilation/test/generic_test.rs | 37 +- .../src/db_index/type/humanize_type.rs | 7 +- .../src/db_index/type/type_decl.rs | 14 +- .../generic/generic_constraint_mismatch.rs | 23 +- .../src/semantic/generic/call_constraint.rs | 87 ++- .../src/semantic/generic/inference/context.rs | 460 ++++++++++++ .../infer_types.rs} | 697 +----------------- .../src/semantic/generic/inference/mod.rs | 247 +++++++ .../src/semantic/generic/inference/resolve.rs | 135 ++++ .../src/semantic/generic/inference/tests.rs | 90 +++ .../instantiate_type/complete_generic_args.rs | 21 +- .../generic/instantiate_type/context.rs | 115 +++ .../infer_call_func_generic.rs | 135 ++-- .../instantiate_conditional_generic.rs | 28 +- .../instantiate_mapped_type.rs | 9 +- .../instantiate_special_generic.rs | 41 +- .../semantic/generic/instantiate_type/mod.rs | 295 +++++--- .../src/semantic/generic/mod.rs | 14 +- .../src/semantic/generic/test.rs | 377 +++++++++- .../src/semantic/generic/type_mapper.rs | 364 +++++++++ .../src/semantic/generic/type_substitutor.rs | 335 --------- .../src/semantic/infer/infer_call/mod.rs | 25 +- .../src/semantic/infer/infer_index/mod.rs | 32 +- .../src/semantic/member/find_index.rs | 19 +- .../src/semantic/member/find_members.rs | 30 +- .../src/semantic/member/infer_raw_member.rs | 15 +- .../collect_callable_overloads.rs | 8 +- .../semantic/type_check/complex_type/mod.rs | 12 +- .../src/semantic/type_check/func_type.rs | 11 +- .../src/semantic/type_check/generic_type.rs | 20 +- .../src/semantic/type_check/simple_type.rs | 13 +- .../completion/providers/function_provider.rs | 34 +- .../src/handlers/hover/function/mod.rs | 12 +- .../src/handlers/hover/hover_builder.rs | 33 +- .../build_signature_helper.rs | 6 +- 36 files changed, 2374 insertions(+), 1451 deletions(-) create mode 100644 crates/emmylua_code_analysis/src/semantic/generic/inference/context.rs rename crates/emmylua_code_analysis/src/semantic/generic/{inference.rs => inference/infer_types.rs} (70%) create mode 100644 crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs create mode 100644 crates/emmylua_code_analysis/src/semantic/generic/inference/resolve.rs create mode 100644 crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs create mode 100644 crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs create mode 100644 crates/emmylua_code_analysis/src/semantic/generic/type_mapper.rs delete mode 100644 crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs index abe30c473..ea0a2f0ca 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs @@ -4,7 +4,7 @@ use smol_str::SmolStr; use crate::{ InFiled, InferFailReason, InferGuardRef, LuaInferCache, LuaInstanceType, LuaMemberId, - LuaMemberOwner, LuaOperatorOwner, TypeOps, TypeSubstitutor, check_type_compact, + LuaMemberOwner, LuaOperatorOwner, TypeMapper, TypeOps, check_type_compact, db_index::{ DbIndex, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaObjectType, LuaOperatorMetaMethod, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, @@ -355,7 +355,7 @@ fn index_generic_members_from_super_generics( db: &DbIndex, cache: &mut LuaInferCache, type_decl_id: &LuaTypeDeclId, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, index_expr: LuaIndexMemberExpr, infer_guard: &InferGuardRef, deep_guard: &mut DeepLevel, @@ -370,7 +370,7 @@ fn index_generic_members_from_super_generics( let type_decl_id = type_decl.get_id(); if let Some(super_types) = type_index.get_super_types(&type_decl_id) { super_types.iter().find_map(|super_type| { - let super_type = instantiate_type_generic(db, super_type, substitutor); + let super_type = instantiate_type_generic(db, super_type, mapper); find_function_type_by_member_key( db, cache, @@ -397,13 +397,13 @@ fn find_generic_member( let base_type = generic_type.get_base_type(); let generic_params = generic_type.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); + let mapper = TypeMapper::from_type_array(generic_params.clone()); if let LuaType::Ref(base_type_decl_id) = &base_type { let result = index_generic_members_from_super_generics( db, cache, base_type_decl_id, - &substitutor, + &mapper, index_expr.clone(), infer_guard, deep_guard, @@ -422,7 +422,7 @@ fn find_generic_member( deep_guard, )?; - Ok(instantiate_type_generic(db, &member_type, &substitutor)) + Ok(instantiate_type_generic(db, &member_type, &mapper)) } fn find_instance_member_decl_type( @@ -772,17 +772,17 @@ fn find_member_by_index_generic( return Err(InferFailReason::None); }; let generic_params = generic.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); + let mapper = TypeMapper::from_type_array(generic_params.clone()); let type_index = db.get_type_index(); let type_decl = type_index .get_type_decl(&type_decl_id) .ok_or(InferFailReason::None)?; if type_decl.is_alias() { - if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { + if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&mapper)) { return find_function_type_by_operator( db, cache, - &instantiate_type_generic(db, &origin_type, &substitutor), + &instantiate_type_generic(db, &origin_type, &mapper), index_expr.clone(), &infer_guard.fork(), deep_guard, @@ -801,9 +801,9 @@ fn find_member_by_index_generic( .get_operator(index_operator_id) .ok_or(InferFailReason::None)?; let operand = index_operator.get_operand(db); - let instianted_operand = instantiate_type_generic(db, &operand, &substitutor); + let instianted_operand = instantiate_type_generic(db, &operand, &mapper); let return_type = - instantiate_type_generic(db, &index_operator.get_result(db)?, &substitutor); + instantiate_type_generic(db, &index_operator.get_result(db)?, &mapper); let result = find_index_metamethod(db, cache, &member_key, &instianted_operand, &return_type); @@ -826,7 +826,7 @@ fn find_member_by_index_generic( let result = find_function_type_by_operator( db, cache, - &instantiate_type_generic(db, &super_type, &substitutor), + &instantiate_type_generic(db, &super_type, &mapper), index_expr.clone(), &infer_guard.fork(), deep_guard, diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 94b316606..1ce5adac2 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -3,8 +3,8 @@ mod test { use emmylua_parser::LuaClosureExpr; use crate::{ - DiagnosticCode, GenericTplId, LuaSignatureId, LuaType, LuaTypeDeclId, TypeSubstitutor, - VirtualWorkspace, complete_type_generic_args, instantiate_type_generic, + DiagnosticCode, GenericTplId, LuaSignatureId, LuaType, LuaTypeDeclId, TypeMapper, + TypeMapperValue, VirtualWorkspace, complete_type_generic_args, instantiate_type_generic, }; #[test] @@ -533,8 +533,8 @@ mod test { ); let generic_ty = ws.ty(r#"Forward<"a" | "b">"#); - let instantiated = - instantiate_type_generic(ws.get_db_mut(), &generic_ty, &TypeSubstitutor::new()); + let mapper = TypeMapper::empty(); + let instantiated = instantiate_type_generic(ws.get_db_mut(), &generic_ty, &mapper); assert_eq!(instantiated, ws.ty(r#""a""#)); } @@ -988,6 +988,35 @@ mod test { assert_eq!(ws.humanize_type(default_type), "string"); } + #[test] + fn test_signature_instantiation_keeps_template_shapes() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@generic T + ---@param x T + ---@return T + local function id(x) + return x + end + "#, + ); + + let closure = ws.get_node::(file_id); + let signature_id = LuaSignatureId::from_closure(file_id, &closure); + let mapper = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let instantiated = instantiate_type_generic( + ws.analysis.compilation.get_db(), + &LuaType::Signature(signature_id), + &mapper, + ); + + assert_eq!(ws.humanize_type(instantiated), "fun(x: string) -> string"); + } + #[test] fn test_bare_generic_type_uses_default() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs index f26732263..05b19bed3 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs @@ -6,8 +6,7 @@ use itertools::Itertools; use crate::{ AsyncState, DbIndex, LuaAliasCallType, LuaConditionalType, LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaSignatureId, - LuaStringTplType, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, TypeSubstitutor, - VariadicType, + LuaStringTplType, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, TypeMapper, VariadicType, }; use super::{LuaAliasCallKind, LuaMultiLineUnion}; @@ -717,8 +716,8 @@ impl<'a> TypeHumanizer<'a> { return Ok(()); } - let substitutor = TypeSubstitutor::from_type_array(generic.get_params().clone()); - if let Some(origin_type) = type_decl.get_alias_origin(self.db, Some(&substitutor)) { + let mapper = TypeMapper::from_type_array(generic.get_params().clone()); + if let Some(origin_type) = type_decl.get_alias_origin(self.db, Some(&mapper)) { w.write_str(" = ")?; let saved = self.level; self.level = self.child_level(); diff --git a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs index ba8ef4a05..7ca509d87 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use smol_str::SmolStr; use crate::{ - DbIndex, FileId, LuaMemberKey, LuaMemberOwner, TypeSubstitutor, db_index::WorkspaceId, + DbIndex, FileId, LuaMemberKey, LuaMemberOwner, TypeMapper, db_index::WorkspaceId, instantiate_type_generic, }; @@ -130,17 +130,13 @@ impl LuaTypeDecl { .map(|idx| &self.id.get_name()[..idx]) } - pub fn get_alias_origin( - &self, - db: &DbIndex, - substitutor: Option<&TypeSubstitutor>, - ) -> Option { + pub fn get_alias_origin(&self, db: &DbIndex, mapper: Option<&TypeMapper>) -> Option { match &self.extra { LuaTypeExtra::Alias { origin: Some(origin), } => { - let substitutor = match substitutor { - Some(substitutor) => substitutor, + let mapper = match mapper { + Some(mapper) => mapper, None => return Some(origin.clone()), }; @@ -153,7 +149,7 @@ impl LuaTypeDecl { return Some(origin.clone()); } - Some(instantiate_type_generic(db, origin, substitutor)) + Some(instantiate_type_generic(db, origin, mapper)) } _ => None, } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs index c535f984a..ae1ba0eb8 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs @@ -14,7 +14,7 @@ use crate::{ DiagnosticCode, DocTypeInferContext, GenericTplId, LuaArrayType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaSignatureId, LuaStringTplType, LuaTupleType, LuaType, LuaTypeNode, LuaUnionType, RenderLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, - TypeSubstitutor, VariadicType, humanize_type, infer_doc_type, instantiate_type_generic, + TypeMapper, VariadicType, humanize_type, infer_doc_type, instantiate_type_generic, }; pub struct GenericConstraintMismatchChecker; @@ -55,7 +55,7 @@ fn check_call_expr( let Some(CallConstraintContext { params, args, - substitutor, + mapper, }) = build_call_constraint_context(semantic_model, &call_expr) else { return Some(()); @@ -76,7 +76,7 @@ fn check_call_expr( param_type, &args, false, - &substitutor, + &mapper, ); } @@ -652,22 +652,21 @@ fn check_doc_type_generic_constraints( .get_generic_params(&type_id)?; let instantiate_arg = explicit_arg_instantiation_flags(&generic_params, explicit_args.len()); - let empty_substitutor = TypeSubstitutor::new(); + let empty_mapper = TypeMapper::empty(); let param_types = explicit_args .iter() .enumerate() .map(|(idx, doc_type)| { let ty = infer_doc_type(doc_ctx, doc_type); if instantiate_arg.get(idx).copied().unwrap_or(false) { - instantiate_type_generic(semantic_model.get_db(), &ty, &empty_substitutor) + instantiate_type_generic(semantic_model.get_db(), &ty, &empty_mapper) } else { ty } }) .collect::>(); - let substitutor = - TypeSubstitutor::from_alias(semantic_model.get_db(), param_types.clone(), type_id); + let mapper = TypeMapper::from_alias(semantic_model.get_db(), param_types.clone(), &type_id); for (i, param_type) in param_types.iter().enumerate() { let Some(explicit_arg) = explicit_args.get(i) else { @@ -681,7 +680,7 @@ fn check_doc_type_generic_constraints( }; let mut extend_type = - instantiate_type_generic(semantic_model.get_db(), &extend_type, &substitutor); + instantiate_type_generic(semantic_model.get_db(), &extend_type, &mapper); extend_type = normalize_keyof_any_constraint(extend_type); let result = semantic_model.type_check_detail(&extend_type, param_type); if result.is_err() { @@ -762,7 +761,7 @@ fn check_param( param_type: &LuaType, args: &[CallConstraintArg], from_union: bool, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, ) -> Option<()> { // 应该先通过泛型体操约束到唯一类型再进行检查 match param_type { @@ -770,7 +769,7 @@ fn check_param( let extend_type = str_tpl_ref.get_constraint().cloned().map(|ty| { normalize_constraint_type( semantic_model.get_db(), - instantiate_type_generic(semantic_model.get_db(), &ty, substitutor), + instantiate_type_generic(semantic_model.get_db(), &ty, mapper), ) }); let arg = args.get(param_index)?; @@ -793,7 +792,7 @@ fn check_param( let extend_type = tpl_ref.get_constraint().cloned().map(|ty| { normalize_constraint_type( semantic_model.get_db(), - instantiate_type_generic(semantic_model.get_db(), &ty, substitutor), + instantiate_type_generic(semantic_model.get_db(), &ty, mapper), ) }); let arg_type = args.get(param_index).map(|arg| &arg.check_type); @@ -812,7 +811,7 @@ fn check_param( union_member_type, args, true, - substitutor, + mapper, ); } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs index e55590b35..70456257f 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs @@ -1,20 +1,22 @@ use std::{ops::Deref, sync::Arc}; use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr}; -use hashbrown::HashSet; +use hashbrown::{HashMap, HashSet}; use rowan::TextRange; use crate::{ DbIndex, DocTypeInferContext, GenericTpl, GenericTplId, LuaFunctionType, LuaSemanticDeclId, - LuaType, LuaTypeNode, SemanticDeclLevel, SemanticModel, TypeOps, TypeSubstitutor, VariadicType, + LuaType, LuaTypeNode, SemanticDeclLevel, SemanticModel, TypeMapper, TypeOps, VariadicType, infer_doc_type, }; +use super::TypeMapperValue; + // 泛型约束上下文 pub struct CallConstraintContext { pub params: Vec<(String, Option)>, pub args: Vec, - pub substitutor: TypeSubstitutor, + pub mapper: TypeMapper, } pub struct CallConstraintArg { @@ -30,11 +32,8 @@ pub fn build_call_constraint_context( let doc_func = infer_call_doc_function(semantic_model, call_expr)?; let mut params = doc_func.get_params().to_vec(); let mut args = get_arg_infos(semantic_model, call_expr)?; - let mut substitutor = TypeSubstitutor::new(); let generic_tpls = collect_func_tpl_ids(¶ms); - if !generic_tpls.is_empty() { - substitutor.prepare_inference_slots(generic_tpls); - } + let mut mapper_builder = CallConstraintMapperBuilder::new(generic_tpls); // 读取显式传入的泛型实参 if let Some(type_list) = call_expr.get_call_generic_type_list() { @@ -42,7 +41,7 @@ pub fn build_call_constraint_context( DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); for (idx, doc_type) in type_list.get_types().enumerate() { let ty = infer_doc_type(doc_ctx, &doc_type); - substitutor.bind_type(GenericTplId::Func(idx as u32), ty); + mapper_builder.bind_type(GenericTplId::Func(idx as u32), ty); } } @@ -65,12 +64,12 @@ pub fn build_call_constraint_context( } } - collect_generic_assignments(&mut substitutor, ¶ms, &args); + collect_generic_assignments(&mut mapper_builder, ¶ms, &args); Some(CallConstraintContext { params, args, - substitutor, + mapper: mapper_builder.into_mapper(), }) } @@ -84,7 +83,7 @@ pub fn normalize_constraint_type(db: &DbIndex, ty: LuaType) -> LuaType { // 收集各个参数对应的泛型推导 fn collect_generic_assignments( - substitutor: &mut TypeSubstitutor, + mapper_builder: &mut CallConstraintMapperBuilder, params: &[(String, Option)], args: &[CallConstraintArg], ) { @@ -95,7 +94,7 @@ fn collect_generic_assignments( let Some(arg) = args.get(idx) else { continue; }; - record_generic_assignment(param_type, &arg.check_type, substitutor); + record_generic_assignment(param_type, &arg.check_type, mapper_builder); } } @@ -256,31 +255,85 @@ fn collect_generic_tpl_from_fallback( fn record_generic_assignment( param_type: &LuaType, arg_type: &LuaType, - substitutor: &mut TypeSubstitutor, + mapper_builder: &mut CallConstraintMapperBuilder, ) { match param_type { LuaType::TplRef(tpl_ref) => { if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); + mapper_builder.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); } } LuaType::ConstTplRef(tpl_ref) => { if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); + mapper_builder.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); } } LuaType::StrTplRef(str_tpl_ref) => { - substitutor.bind_type(str_tpl_ref.get_tpl_id(), arg_type.clone()); + mapper_builder.bind_type(str_tpl_ref.get_tpl_id(), arg_type.clone()); } LuaType::Variadic(variadic) => { if let Some(inner) = variadic.get_type(0) { - record_generic_assignment(inner, arg_type, substitutor); + record_generic_assignment(inner, arg_type, mapper_builder); } } _ => {} } } +struct CallConstraintMapperBuilder { + bindings: Vec<(GenericTplId, Option)>, + binding_indices: HashMap, +} + +impl CallConstraintMapperBuilder { + fn new(generic_tpls: HashSet) -> Self { + let mut bindings = Vec::with_capacity(generic_tpls.len()); + let mut binding_indices = HashMap::with_capacity(generic_tpls.len()); + for tpl_id in generic_tpls { + binding_indices.insert(tpl_id, bindings.len()); + bindings.push((tpl_id, None)); + } + + Self { + bindings, + binding_indices, + } + } + + fn bind_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType) { + if tpl_id.is_conditional_infer() { + return; + } + + if let Some(index) = self.binding_indices.get(&tpl_id).copied() { + if let Some((_, existing_type)) = self.bindings.get_mut(index) + && existing_type.is_none() + { + *existing_type = Some(replace_type); + } + return; + } + + self.binding_indices.insert(tpl_id, self.bindings.len()); + self.bindings.push((tpl_id, Some(replace_type))); + } + + fn into_mapper(self) -> TypeMapper { + let (sources, targets): (Vec<_>, Vec<_>) = self + .bindings + .into_iter() + .map(|(tpl_id, ty)| { + ( + tpl_id, + ty.map(TypeMapperValue::type_value) + .unwrap_or(TypeMapperValue::None), + ) + }) + .unzip(); + TypeMapper::from_values(sources, targets) + } +} + // 解析冒号调用时调用者的具体类型 fn infer_call_source_type( semantic_model: &SemanticModel, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/context.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/context.rs new file mode 100644 index 000000000..c111c832a --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/context.rs @@ -0,0 +1,460 @@ +use std::{cell::RefCell, rc::Rc, sync::Arc}; + +use emmylua_parser::LuaCallExpr; +use hashbrown::{HashMap, HashSet}; + +use crate::{ + DbIndex, GenericTpl, GenericTplId, LuaInferCache, LuaType, check_type_compact, + instantiate_type_generic, semantic::generic::regularize_tpl_candidate_type, +}; + +use crate::semantic::generic::{TypeMapper, TypeMapperValue}; + +use super::{ + InferenceCandidate, InferenceCandidateView, InferencePriority, InferenceResult, + InferenceVariance, is_tpl_at_top_level, resolve::resolve_info, +}; + +#[derive(Debug, Clone)] +pub(super) struct InferenceRecord { + pub(super) candidate: InferenceCandidate, + pub(super) priority: InferencePriority, +} + +#[derive(Debug, Clone, Default)] +pub(super) struct InferenceInfo { + // 协变候选, 例如参数位置或返回值位置的正向推断. + pub(super) covariant: Vec, + // 逆变候选, 例如函数参数约束的反向传播. + pub(super) contravariant: Vec, + // 多返回值或变参展开时收集的候选. + pub(super) multi: Vec, + // 已经固定下来的结果, 例如显式泛型或变参参数表. + pub(super) fixed: Option, + // 最近一次写入的推断结果, 便于在后续阶段判断是否需要重算. + pub(super) inferred: Option, + // 是否已经被 fixing mapper 读取并固定. + pub(super) is_fixed: bool, + // 当前信息是否始终处于顶层位置. + pub(super) top_level: bool, + // 该模板收到的最高优先级. + pub(super) priority: Option, +} + +impl InferenceInfo { + fn new() -> Self { + Self { + top_level: true, + ..Self::default() + } + } + + fn add_candidate( + &mut self, + variance: InferenceVariance, + candidate: InferenceCandidate, + top_level: bool, + priority: InferencePriority, + ) { + if self.is_fixed { + return; + } + + self.top_level &= top_level; + self.priority = Some( + self.priority + .map_or(priority, |current| current.max(priority)), + ); + self.inferred = None; + let record = InferenceRecord { + candidate, + priority, + }; + match variance { + InferenceVariance::Covariant => self.covariant.push(record), + InferenceVariance::Contravariant => self.contravariant.push(record), + } + } +} + +#[derive(Debug)] +pub(in crate::semantic::generic) struct InferenceContext<'a> { + pub db: &'a DbIndex, + pub cache: &'a mut LuaInferCache, + pub call_expr: Option, + infos: HashMap, +} + +impl<'a> InferenceContext<'a> { + pub fn new( + db: &'a DbIndex, + cache: &'a mut LuaInferCache, + call_expr: Option, + ) -> Self { + Self { + db, + cache, + call_expr, + infos: HashMap::new(), + } + } + + pub fn prepare_inference_slots(&mut self, tpl_ids: HashSet) { + for tpl_id in tpl_ids { + if tpl_id.is_conditional_infer() { + continue; + } + + self.infos.entry(tpl_id).or_insert_with(InferenceInfo::new); + } + } + + pub fn has_unresolved_inference_slots(&self) -> bool { + self.infos.values().any(|info| { + info.fixed.is_none() + && info.covariant.is_empty() + && info.contravariant.is_empty() + && info.multi.is_empty() + }) + } + + pub fn fix_type(&mut self, tpl_id: GenericTplId, ty: LuaType) { + if tpl_id.is_conditional_infer() { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + info.fixed = Some(InferenceResult::Type(ty)); + info.inferred = info.fixed.clone(); + info.is_fixed = true; + } + + pub fn add_variadic_params( + &mut self, + tpl_id: GenericTplId, + params: Vec<(String, Option)>, + ) { + if !self.can_bind(tpl_id) { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + if info.fixed.is_some() + || !info.multi.is_empty() + || !info.covariant.is_empty() + || !info.contravariant.is_empty() + { + return; + } + + info.fixed = Some(InferenceResult::VariadicParams(params)); + info.inferred = info.fixed.clone(); + info.is_fixed = true; + } + + pub fn add_variadic_base(&mut self, tpl_id: GenericTplId, base: LuaType) { + if !self.can_bind(tpl_id) { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + if info.fixed.is_some() + || !info.multi.is_empty() + || !info.covariant.is_empty() + || !info.contravariant.is_empty() + { + return; + } + + info.fixed = Some(InferenceResult::VariadicBase(base)); + info.inferred = info.fixed.clone(); + info.is_fixed = true; + } + + pub fn insert_type( + &mut self, + tpl_id: GenericTplId, + candidate: InferenceCandidate, + variance: InferenceVariance, + top_level: bool, + priority: InferencePriority, + ) { + if !self.can_bind(tpl_id) { + return; + } + + self.infos + .entry(tpl_id) + .or_default() + .add_candidate(variance, candidate, top_level, priority); + } + + pub fn insert_multi_types( + &mut self, + tpl_id: GenericTplId, + types: Vec, + view: InferenceCandidateView, + top_level: bool, + priority: InferencePriority, + ) { + if !self.can_bind(tpl_id) { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + let fixed = if types.len() == 1 { + Some(InferenceResult::VariadicParams(vec![( + "var0".to_string(), + Some(types[0].clone()), + )])) + } else { + Some(InferenceResult::MultiTypes(types.clone())) + }; + if info.fixed.is_some() + || !info.multi.is_empty() + || !info.covariant.is_empty() + || !info.contravariant.is_empty() + { + if top_level { + info.fixed = fixed; + info.inferred = info.fixed.clone(); + info.is_fixed = true; + } + return; + } + + info.multi = types + .into_iter() + .map(|ty| InferenceRecord { + candidate: InferenceCandidate { ty, view }, + priority, + }) + .collect(); + info.top_level &= top_level; + info.priority = Some( + info.priority + .map_or(priority, |current| current.max(priority)), + ); + info.inferred = None; + } + + pub(super) fn inferred_variadic_len(&self, tpl_id: GenericTplId) -> Option { + let info = self.infos.get(&tpl_id)?; + if let Some(fixed) = &info.fixed { + return match fixed { + InferenceResult::Type(_) => Some(1), + InferenceResult::MultiTypes(types) => Some(types.len().max(1)), + InferenceResult::VariadicParams(params) => Some(params.len().max(1)), + InferenceResult::VariadicBase(_) => None, + }; + } + + if !info.multi.is_empty() { + return Some(info.multi.len().max(1)); + } + + if !info.covariant.is_empty() || !info.contravariant.is_empty() { + return Some(1); + } + + None + } + + pub fn fixing_mapper<'b>( + &mut self, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + ) -> TypeMapper { + self.mapper_inner(generic_tpls, return_type, true) + } + + pub fn non_fixing_mapper<'b>( + &mut self, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + ) -> TypeMapper { + self.mapper_inner(generic_tpls, return_type, false) + } + + fn mapper_inner<'b>( + &mut self, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + fixing: bool, + ) -> TypeMapper { + let generic_tpls = generic_tpls.into_iter().collect::>(); + let sources = generic_tpls + .iter() + .map(|tpl| tpl.get_tpl_id()) + .collect::>(); + let mut fallback_indices = HashMap::with_capacity(generic_tpls.len()); + let fallback_targets = generic_tpls + .iter() + .enumerate() + .map(|(index, tpl)| { + fallback_indices.entry(tpl.get_tpl_id()).or_insert(index); + self.infos + .get(&tpl.get_tpl_id()) + .and_then(|info| info.inferred.clone().or_else(|| info.fixed.clone())) + .map(inference_result_to_mapper_value) + .unwrap_or(TypeMapperValue::None) + }) + .collect::>(); + let fallback_targets = Rc::new(RefCell::new(fallback_targets)); + let fallback_mapper = TypeMapper::from_inference_fallback( + Rc::new(fallback_indices), + fallback_targets.clone(), + ); + let targets = (0..generic_tpls.len()) + .map(|index| { + let value = self.get_inferred_mapper_value( + &generic_tpls, + index, + return_type, + &fallback_mapper, + fixing, + ); + if value != TypeMapperValue::None + && let Some(target) = fallback_targets.borrow_mut().get_mut(index) + { + *target = value.clone(); + } + value + }) + .collect(); + TypeMapper::from_values(sources, targets) + } + + fn can_bind(&self, tpl_id: GenericTplId) -> bool { + !tpl_id.is_conditional_infer() + } + + fn get_inferred_mapper_value( + &mut self, + generic_tpls: &[&Arc], + index: usize, + return_type: &LuaType, + fallback_mapper: &TypeMapper, + fixing: bool, + ) -> TypeMapperValue { + let tpl = generic_tpls[index]; + let tpl_id = tpl.get_tpl_id(); + let return_top_level = is_tpl_at_top_level(self.db, return_type, tpl_id); + + if let Some(info) = self.infos.get_mut(&tpl_id) + && let Some(inferred) = info.inferred.clone() + { + if fixing { + info.is_fixed = true; + } + return inference_result_to_mapper_value(inferred); + } + + let inferred = self.infos.get_mut(&tpl_id).and_then(|info| { + if fixing { + info.is_fixed = true; + } + + if let Some(inferred) = &info.inferred { + return Some(inferred.clone()); + } + + let result = resolve_info(self.db, tpl, info, return_top_level, fallback_mapper)?; + info.inferred = Some(result.clone()); + Some(result) + }); + let value = if let Some(inferred) = inferred { + let value = inference_result_to_mapper_value(inferred); + apply_inferred_constraint(self.db, tpl, value, fallback_mapper) + } else { + TypeMapperValue::None + }; + + if value != TypeMapperValue::None + && let Some(info) = self.infos.get_mut(&tpl_id) + { + info.inferred = Some(mapper_value_to_inference_result(value.clone())); + } + + value + } +} + +pub(super) fn inference_result_to_mapper_value(result: InferenceResult) -> TypeMapperValue { + match result { + InferenceResult::Type(ty) => TypeMapperValue::type_value(ty), + InferenceResult::MultiTypes(types) => TypeMapperValue::MultiTypes( + types + .into_iter() + .map(|ty| { + TypeMapperValue::type_value(ty) + .raw_type() + .unwrap_or(LuaType::Unknown) + }) + .collect(), + ), + InferenceResult::VariadicParams(params) => TypeMapperValue::params_value(params), + InferenceResult::VariadicBase(base) => TypeMapperValue::MultiBase( + TypeMapperValue::type_value(base) + .raw_type() + .unwrap_or(LuaType::Unknown), + ), + } +} + +fn mapper_value_to_inference_result(value: TypeMapperValue) -> InferenceResult { + match value { + TypeMapperValue::None => InferenceResult::Type(LuaType::Unknown), + TypeMapperValue::Type(ty) => InferenceResult::Type(ty), + TypeMapperValue::Params(params) => InferenceResult::VariadicParams(params), + TypeMapperValue::MultiTypes(types) => InferenceResult::MultiTypes(types), + TypeMapperValue::MultiBase(base) => InferenceResult::VariadicBase(base), + } +} + +fn apply_inferred_constraint( + db: &DbIndex, + tpl: &GenericTpl, + value: TypeMapperValue, + mapper: &TypeMapper, +) -> TypeMapperValue { + let TypeMapperValue::Type(inferred_type) = &value else { + return value; + }; + + let Some(constraint) = tpl.get_constraint() else { + return value; + }; + + let instantiated_constraint = instantiate_type_generic(db, constraint, mapper); + if inferred_satisfies_constraint(db, inferred_type, &instantiated_constraint) { + return value; + } + + value +} + +fn inferred_satisfies_constraint(db: &DbIndex, inferred: &LuaType, constraint: &LuaType) -> bool { + if inferred == constraint || check_type_compact(db, inferred, constraint).is_ok() { + return true; + } + + match constraint { + LuaType::Union(union) => { + return union + .into_vec() + .iter() + .any(|member| inferred_satisfies_constraint(db, inferred, member)); + } + LuaType::MultiLineUnion(union) => { + return union + .get_unions() + .iter() + .any(|(member, _)| inferred_satisfies_constraint(db, inferred, member)); + } + _ => {} + } + + let regular = regularize_tpl_candidate_type(db, inferred.clone()); + regular != *inferred && inferred_satisfies_constraint(db, ®ular, constraint) +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs similarity index 70% rename from crates/emmylua_code_analysis/src/semantic/generic/inference.rs rename to crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs index 896eb00a4..cd1a57133 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/inference.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs @@ -1,399 +1,25 @@ use std::{collections::HashMap as StdHashMap, ops::Deref, sync::Arc}; -use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr}; -use hashbrown::{HashMap, HashSet}; +use emmylua_parser::{LuaAstNode, LuaExpr}; use itertools::Itertools; use rowan::NodeOrToken; use smol_str::SmolStr; use crate::{ - DbIndex, GenericTpl, GenericTplId, InferFailReason, InferGuard, InferGuardRef, LuaFunctionType, - LuaGenericType, LuaInferCache, LuaMemberInfo, LuaMemberKey, LuaMemberOwner, LuaSemanticDeclId, - LuaTupleType, LuaType, LuaTypeDeclId, LuaTypeNode, LuaUnionType, SemanticDeclLevel, TypeOps, - TypeSubstitutor, VariadicType, check_type_compact, infer_node_semantic_decl, + InferFailReason, InferGuard, InferGuardRef, LuaFunctionType, LuaGenericType, LuaMemberInfo, + LuaMemberKey, LuaMemberOwner, LuaSemanticDeclId, LuaTupleType, LuaType, LuaTypeNode, + LuaUnionType, SemanticDeclLevel, VariadicType, check_type_compact, infer_node_semantic_decl, instantiate_type_generic, semantic::{ - generic::{ - is_primitive_or_literal_type, regularize_tpl_candidate_type, widen_tpl_candidate_type, - }, + generic::TypeMapper, member::{find_index_operations, get_member_map}, }, }; -use super::type_substitutor::TplBinding; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(in crate::semantic::generic) enum InferenceCandidateView { - FreshExpression, - RegularType, - ConstPreserving, - Ordinary, -} - -#[derive(Debug, Clone, PartialEq)] -pub(in crate::semantic::generic) struct InferenceCandidate { - ty: LuaType, - view: InferenceCandidateView, -} - -impl InferenceCandidate { - pub(in crate::semantic::generic) fn from_expr_arg(expr: Option<&LuaExpr>, ty: LuaType) -> Self { - if is_literal_candidate(&ty) { - if expr.is_some_and(is_fresh_literal_expr) { - return Self::fresh_expression(ty); - } - - return Self::regular_type(ty); - } - - Self::ordinary(ty) - } - - pub(in crate::semantic::generic) fn regular_type(ty: LuaType) -> Self { - Self { - ty, - view: InferenceCandidateView::RegularType, - } - } - - pub(in crate::semantic::generic) fn const_preserving(ty: LuaType) -> Self { - Self { - ty, - view: InferenceCandidateView::ConstPreserving, - } - } - - pub(in crate::semantic::generic) fn ordinary(ty: LuaType) -> Self { - Self { - ty, - view: InferenceCandidateView::Ordinary, - } - } - - fn fresh_expression(ty: LuaType) -> Self { - Self { - ty, - view: InferenceCandidateView::FreshExpression, - } - } - - fn candidate_type(&self) -> LuaType { - self.ty.clone() - } - - fn is_const_preserving(&self) -> bool { - self.view == InferenceCandidateView::ConstPreserving - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub(in crate::semantic::generic) enum InferencePriority { - Normal, - Return, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(in crate::semantic::generic) enum InferenceVariance { - Covariant, - Contravariant, -} - -impl InferenceVariance { - fn flip(self) -> Self { - match self { - InferenceVariance::Covariant => InferenceVariance::Contravariant, - InferenceVariance::Contravariant => InferenceVariance::Covariant, - } - } -} - -#[derive(Debug, Clone)] -pub(in crate::semantic::generic) enum InferenceResult { - Type(LuaType), - MultiTypes(Vec), - VariadicParams(Vec<(String, Option)>), - VariadicBase(LuaType), -} - -#[derive(Debug, Clone)] -struct InferenceRecord { - candidate: InferenceCandidate, - priority: InferencePriority, -} - -#[derive(Debug, Clone, Default)] -pub(in crate::semantic::generic) struct InferenceInfo { - // 协变候选, 例如参数位置或返回值位置的正向推断. - covariant: Vec, - // 逆变候选, 例如函数参数约束的反向传播. - contravariant: Vec, - // 多返回值或变参展开时收集的候选. - multi: Vec, - // 已经固定下来的结果, 例如显式泛型或变参参数表. - fixed: Option, - // 最近一次写入的推断结果, 便于在后续阶段判断是否需要重算. - inferred: Option, - // 当前信息是否始终处于顶层位置. - top_level: bool, - // 该模板收到的最高优先级. - priority: Option, -} - -impl InferenceInfo { - fn new() -> Self { - Self { - top_level: true, - ..Self::default() - } - } - - fn add_candidate( - &mut self, - variance: InferenceVariance, - candidate: InferenceCandidate, - top_level: bool, - priority: InferencePriority, - ) { - self.top_level &= top_level; - self.priority = Some( - self.priority - .map_or(priority, |current| current.max(priority)), - ); - self.inferred = None; - let record = InferenceRecord { - candidate, - priority, - }; - match variance { - InferenceVariance::Covariant => self.covariant.push(record), - InferenceVariance::Contravariant => self.contravariant.push(record), - } - } -} - -#[derive(Debug)] -pub(in crate::semantic::generic) struct InferenceContext<'a> { - pub db: &'a DbIndex, - pub cache: &'a mut LuaInferCache, - pub call_expr: Option, - infos: HashMap, -} - -impl<'a> InferenceContext<'a> { - pub fn new( - db: &'a DbIndex, - cache: &'a mut LuaInferCache, - call_expr: Option, - ) -> Self { - Self { - db, - cache, - call_expr, - infos: HashMap::new(), - } - } - - pub fn prepare_inference_slots(&mut self, tpl_ids: HashSet) { - for tpl_id in tpl_ids { - if tpl_id.is_conditional_infer() { - continue; - } - - self.infos.entry(tpl_id).or_insert_with(InferenceInfo::new); - } - } - - pub fn has_unresolved_inference_slots(&self) -> bool { - self.infos.values().any(|info| { - info.fixed.is_none() - && info.covariant.is_empty() - && info.contravariant.is_empty() - && info.multi.is_empty() - }) - } - - pub fn fix_type(&mut self, tpl_id: GenericTplId, ty: LuaType) { - if tpl_id.is_conditional_infer() { - return; - } - - let info = self.infos.entry(tpl_id).or_default(); - info.fixed = Some(InferenceResult::Type(ty)); - info.inferred = info.fixed.clone(); - } - - pub fn add_variadic_params( - &mut self, - tpl_id: GenericTplId, - params: Vec<(String, Option)>, - ) { - if !self.can_bind(tpl_id) { - return; - } - - let info = self.infos.entry(tpl_id).or_default(); - if info.fixed.is_some() - || !info.multi.is_empty() - || !info.covariant.is_empty() - || !info.contravariant.is_empty() - { - return; - } - - info.fixed = Some(InferenceResult::VariadicParams(params)); - info.inferred = info.fixed.clone(); - } - - pub fn add_variadic_base(&mut self, tpl_id: GenericTplId, base: LuaType) { - if !self.can_bind(tpl_id) { - return; - } - - let info = self.infos.entry(tpl_id).or_default(); - if info.fixed.is_some() - || !info.multi.is_empty() - || !info.covariant.is_empty() - || !info.contravariant.is_empty() - { - return; - } - - info.fixed = Some(InferenceResult::VariadicBase(base)); - info.inferred = info.fixed.clone(); - } - - pub fn insert_type( - &mut self, - tpl_id: GenericTplId, - candidate: InferenceCandidate, - variance: InferenceVariance, - top_level: bool, - priority: InferencePriority, - ) { - if !self.can_bind(tpl_id) { - return; - } - - self.infos - .entry(tpl_id) - .or_default() - .add_candidate(variance, candidate, top_level, priority); - } - - pub fn insert_multi_types( - &mut self, - tpl_id: GenericTplId, - types: Vec, - view: InferenceCandidateView, - top_level: bool, - priority: InferencePriority, - ) { - if !self.can_bind(tpl_id) { - return; - } - - let info = self.infos.entry(tpl_id).or_default(); - let fixed = if types.len() == 1 { - Some(InferenceResult::VariadicParams(vec![( - "var0".to_string(), - Some(types[0].clone()), - )])) - } else { - Some(InferenceResult::MultiTypes(types.clone())) - }; - if info.fixed.is_some() - || !info.multi.is_empty() - || !info.covariant.is_empty() - || !info.contravariant.is_empty() - { - if top_level { - info.fixed = fixed; - info.inferred = info.fixed.clone(); - } - return; - } - - info.multi = types - .into_iter() - .map(|ty| InferenceRecord { - candidate: InferenceCandidate { ty, view }, - priority, - }) - .collect(); - info.top_level &= top_level; - info.priority = Some( - info.priority - .map_or(priority, |current| current.max(priority)), - ); - info.inferred = None; - } - - fn inferred_variadic_len(&self, tpl_id: GenericTplId) -> Option { - let info = self.infos.get(&tpl_id)?; - if let Some(fixed) = &info.fixed { - return match fixed { - InferenceResult::Type(_) => Some(1), - InferenceResult::MultiTypes(types) => Some(types.len().max(1)), - InferenceResult::VariadicParams(params) => Some(params.len().max(1)), - InferenceResult::VariadicBase(_) => None, - }; - } - - if !info.multi.is_empty() { - return Some(info.multi.len().max(1)); - } - - if !info.covariant.is_empty() || !info.contravariant.is_empty() { - return Some(1); - } - - None - } - - pub fn bridge_to_substitutor<'b>( - &mut self, - substitutor: &mut TypeSubstitutor, - generic_tpls: impl IntoIterator>, - return_type: &LuaType, - ) { - let generic_tpls = generic_tpls.into_iter().collect::>(); - let tpl_ids = generic_tpls.iter().map(|tpl| tpl.get_tpl_id()).collect(); - substitutor.prepare_inference_slots(tpl_ids); - self.bridge_to_substitutor_inner(substitutor, generic_tpls, return_type); - } - - pub fn bridge_resolved_to_substitutor<'b>( - &mut self, - substitutor: &mut TypeSubstitutor, - generic_tpls: impl IntoIterator>, - return_type: &LuaType, - ) { - self.bridge_to_substitutor_inner(substitutor, generic_tpls, return_type); - } - - fn bridge_to_substitutor_inner<'b>( - &mut self, - substitutor: &mut TypeSubstitutor, - generic_tpls: impl IntoIterator>, - return_type: &LuaType, - ) { - for tpl in generic_tpls { - let tpl_id = tpl.get_tpl_id(); - let return_top_level = is_tpl_at_top_level(self.db, return_type, tpl_id); - let result = self - .infos - .get(&tpl_id) - .and_then(|info| resolve_info(self.db, tpl, info, return_top_level, substitutor)); - if let Some(result) = result { - write_result_to_substitutor(substitutor, tpl_id, result); - } - } - } - - fn can_bind(&self, tpl_id: GenericTplId) -> bool { - !tpl_id.is_conditional_infer() - } -} +use super::{ + InferenceCandidate, InferenceCandidateView, InferenceContext, InferencePriority, + InferenceVariance, escape_alias, get_str_tpl_infer_type, +}; pub(in crate::semantic::generic) fn infer_types( context: &mut InferenceContext, @@ -1007,14 +633,12 @@ fn generic_infer_types( .get_type_decl(target_base) .ok_or(InferFailReason::None)?; if target_decl.is_alias() { - let substitutor = TypeSubstitutor::from_alias( + let mapper = TypeMapper::from_alias( context.db, target_generic.get_params().clone(), - target_base.clone(), + target_base, ); - if let Some(origin_type) = - target_decl.get_alias_origin(context.db, Some(&substitutor)) - { + if let Some(origin_type) = target_decl.get_alias_origin(context.db, Some(&mapper)) { return generic_infer_types( context, source_generic, @@ -1031,10 +655,9 @@ fn generic_infer_types( { for mut super_type in super_types { if super_type.contains_tpl_node() { - let substitutor = - TypeSubstitutor::from_type_array(target_generic.get_params().clone()); - super_type = - instantiate_type_generic(context.db, &super_type, &substitutor); + let mapper = + TypeMapper::from_type_array(target_generic.get_params().clone()); + super_type = instantiate_type_generic(context.db, &super_type, &mapper); } generic_infer_types( context, @@ -1102,9 +725,9 @@ fn generic_infer_types( } } _ => { - let substitutor = TypeSubstitutor::new(); + let mapper = TypeMapper::empty(); let generic_ty = LuaType::Generic(source_generic.clone().into()); - let ty = instantiate_type_generic(context.db, &generic_ty, &substitutor); + let ty = instantiate_type_generic(context.db, &generic_ty, &mapper); if LuaType::from(source_generic.clone()) != ty { infer_types_inner( context, @@ -1935,289 +1558,3 @@ fn check_lambda_inference( Err(InferFailReason::UnResolveSignatureReturn(signature_id)) } - -fn resolve_info( - db: &DbIndex, - tpl: &GenericTpl, - info: &InferenceInfo, - return_top_level: bool, - substitutor: &TypeSubstitutor, -) -> Option { - if let Some(fixed) = &info.fixed { - return Some(fixed.clone()); - } - - if !info.multi.is_empty() { - let primitive_constraint = tpl.get_constraint().is_some_and(|constraint| { - let constraint = instantiate_type_generic(db, constraint, substitutor); - is_primitive_or_literal_type(&constraint) - }); - let const_preserving = info - .multi - .iter() - .any(|record| record.candidate.is_const_preserving()); - let preserve_root_literal_form = - primitive_constraint || const_preserving || return_top_level; - return Some(InferenceResult::MultiTypes( - info.multi - .iter() - .map(|record| { - resolve_candidate_type(db, &record.candidate, preserve_root_literal_form) - }) - .collect(), - )); - } - - if !info.covariant.is_empty() { - return Some(InferenceResult::Type(resolve_covariant_candidates( - db, - tpl, - info, - return_top_level, - substitutor, - ))); - } - - if !info.contravariant.is_empty() { - return Some(InferenceResult::Type(resolve_contravariant_candidates( - db, info, - ))); - } - - None -} - -fn write_result_to_substitutor( - substitutor: &mut TypeSubstitutor, - tpl_id: GenericTplId, - result: InferenceResult, -) { - match result { - InferenceResult::Type(ty) => substitutor.bind_type(tpl_id, ty), - InferenceResult::MultiTypes(types) => { - substitutor.bind(tpl_id, TplBinding::InferredMultiTypes(types)); - } - InferenceResult::VariadicParams(params) => { - substitutor.bind(tpl_id, TplBinding::VariadicParams(params)); - } - InferenceResult::VariadicBase(base) => { - substitutor.bind(tpl_id, TplBinding::VariadicBase(base)); - } - } -} - -fn resolve_covariant_candidates( - db: &DbIndex, - tpl: &GenericTpl, - info: &InferenceInfo, - return_top_level: bool, - substitutor: &TypeSubstitutor, -) -> LuaType { - let primitive_constraint = tpl - .get_constraint() - .map(|constraint| { - let constraint = instantiate_type_generic(db, constraint, substitutor); - is_primitive_or_literal_type(&constraint) - }) - .unwrap_or(false); - let const_preserving = info - .covariant - .iter() - .any(|record| record.candidate.is_const_preserving()); - let preserve_root_literal_form = - primitive_constraint || const_preserving || !info.top_level || return_top_level; - - combine_records( - db, - &info.covariant, - info.priority.unwrap_or(InferencePriority::Normal), - preserve_root_literal_form, - TypeOps::Union, - ) -} - -fn resolve_contravariant_candidates(db: &DbIndex, info: &InferenceInfo) -> LuaType { - combine_records( - db, - &info.contravariant, - info.priority.unwrap_or(InferencePriority::Normal), - true, - TypeOps::Intersect, - ) -} - -fn combine_records( - db: &DbIndex, - records: &[InferenceRecord], - max_priority: InferencePriority, - preserve_root_literal_form: bool, - op: TypeOps, -) -> LuaType { - let mut selected = records - .iter() - .filter(|record| record.priority == max_priority) - .map(|record| resolve_candidate_type(db, &record.candidate, preserve_root_literal_form)); - - let Some(first) = selected.next() else { - return LuaType::Unknown; - }; - - selected.fold(first, |acc, ty| op.apply(db, &acc, &ty)) -} - -fn resolve_candidate_type( - db: &DbIndex, - candidate: &InferenceCandidate, - preserve_root_literal_form: bool, -) -> LuaType { - if candidate.is_const_preserving() { - // std.ConstTpl 需要保留结构字面量, 例如 tuple/object/table const. - return candidate.candidate_type(); - } - - if preserve_root_literal_form { - return regularize_tpl_candidate_type(db, candidate.candidate_type()); - } - - match candidate.view { - InferenceCandidateView::FreshExpression | InferenceCandidateView::Ordinary => { - widen_tpl_candidate_type(db, candidate.candidate_type()) - } - _ => regularize_tpl_candidate_type(db, candidate.candidate_type()), - } -} - -pub(in crate::semantic::generic) fn is_literal_candidate(ty: &LuaType) -> bool { - match ty { - LuaType::StringConst(_) - | LuaType::DocStringConst(_) - | LuaType::IntegerConst(_) - | LuaType::DocIntegerConst(_) - | LuaType::FloatConst(_) - | LuaType::BooleanConst(_) - | LuaType::DocBooleanConst(_) - | LuaType::TableConst(_) => true, - LuaType::Union(union) => union.into_vec().iter().any(is_literal_candidate), - LuaType::Tuple(tuple) => tuple.get_types().iter().any(is_literal_candidate), - LuaType::Variadic(variadic) => match variadic.deref() { - VariadicType::Base(base) => is_literal_candidate(base), - VariadicType::Multi(types) => types.iter().any(is_literal_candidate), - }, - _ => false, - } -} - -fn is_fresh_literal_expr(expr: &LuaExpr) -> bool { - match expr { - LuaExpr::LiteralExpr(_) | LuaExpr::TableExpr(_) => true, - LuaExpr::ParenExpr(paren) => paren - .get_expr() - .is_some_and(|expr| is_fresh_literal_expr(&expr)), - _ => false, - } -} - -fn get_str_tpl_infer_type(name: &str) -> LuaType { - match name { - "unknown" => LuaType::Unknown, - "never" => LuaType::Never, - "nil" | "void" => LuaType::Nil, - "any" => LuaType::Any, - "userdata" => LuaType::Userdata, - "thread" => LuaType::Thread, - "boolean" | "bool" => LuaType::Boolean, - "string" => LuaType::String, - "integer" | "int" => LuaType::Integer, - "number" => LuaType::Number, - "io" => LuaType::Io, - "self" => LuaType::SelfInfer, - "global" => LuaType::Global, - "function" => LuaType::Function, - _ => LuaType::Ref(LuaTypeDeclId::global(name)), - } -} - -fn escape_alias(db: &DbIndex, may_alias: &LuaType) -> LuaType { - if let LuaType::Ref(type_id) = may_alias - && let Some(type_decl) = db.get_type_index().get_type_decl(type_id) - && type_decl.is_alias() - && let Some(origin_type) = type_decl.get_alias_origin(db, None) - { - return origin_type.clone(); - } - - may_alias.clone() -} - -fn is_tpl_at_top_level(db: &DbIndex, ty: &LuaType, tpl_id: GenericTplId) -> bool { - is_tpl_at_top_level_with_guard(db, ty, tpl_id, &mut HashSet::new()) -} - -fn is_tpl_at_top_level_with_guard( - db: &DbIndex, - ty: &LuaType, - tpl_id: GenericTplId, - visited_aliases: &mut HashSet, -) -> bool { - match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, - LuaType::Union(union) => union.into_vec().iter().any(|member| { - let mut branch_aliases = visited_aliases.clone(); - is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) - }), - LuaType::MultiLineUnion(multi) => multi.get_unions().iter().any(|(member, _)| { - let mut branch_aliases = visited_aliases.clone(); - is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) - }), - LuaType::Generic(generic) => { - let type_decl_id = generic.get_base_type_id_ref(); - let Some(alias_param) = - get_transparent_alias_param_index(db, type_decl_id, visited_aliases) - else { - return false; - }; - - generic.get_params().get(alias_param).is_some_and(|param| { - is_tpl_at_top_level_with_guard(db, param, tpl_id, visited_aliases) - }) - } - _ => false, - } -} - -fn get_transparent_alias_param_index( - db: &DbIndex, - type_decl_id: &LuaTypeDeclId, - visited_aliases: &mut HashSet, -) -> Option { - if !visited_aliases.insert(type_decl_id.clone()) { - return None; - } - - let type_decl = db.get_type_index().get_type_decl(type_decl_id)?; - if !type_decl.is_alias() { - return None; - }; - let origin = type_decl.get_alias_ref()?; - - match origin { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) - if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => - { - Some(tpl.get_tpl_id().get_idx()) - } - LuaType::Generic(generic) => { - get_transparent_alias_param_index(db, generic.get_base_type_id_ref(), visited_aliases) - .and_then(|alias_param| generic.get_params().get(alias_param)) - .and_then(|param| match param { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) - if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => - { - Some(tpl.get_tpl_id().get_idx()) - } - _ => None, - }) - } - _ => None, - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs new file mode 100644 index 000000000..d9ae9752d --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs @@ -0,0 +1,247 @@ +use std::ops::Deref; + +use emmylua_parser::LuaExpr; +use hashbrown::HashSet; + +use crate::{DbIndex, GenericTplId, LuaType, LuaTypeDeclId, VariadicType}; + +mod context; +mod infer_types; +mod resolve; + +#[cfg(test)] +mod tests; + +pub(in crate::semantic::generic) use context::InferenceContext; +pub(in crate::semantic::generic) use infer_types::{ + infer_type_list, infer_types_from_expr, multi_param_infer_multi_return, + return_type_infer_types, variadic_infer_types, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(in crate::semantic::generic) enum InferenceCandidateView { + FreshExpression, + RegularType, + ConstPreserving, + Ordinary, +} + +#[derive(Debug, Clone, PartialEq)] +pub(in crate::semantic::generic) struct InferenceCandidate { + pub(super) ty: LuaType, + pub(super) view: InferenceCandidateView, +} + +impl InferenceCandidate { + pub(in crate::semantic::generic) fn from_expr_arg(expr: Option<&LuaExpr>, ty: LuaType) -> Self { + if is_literal_candidate(&ty) { + if expr.is_some_and(is_fresh_literal_expr) { + return Self::fresh_expression(ty); + } + + return Self::regular_type(ty); + } + + Self::ordinary(ty) + } + + pub(in crate::semantic::generic) fn regular_type(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::RegularType, + } + } + + pub(in crate::semantic::generic) fn const_preserving(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::ConstPreserving, + } + } + + pub(in crate::semantic::generic) fn ordinary(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::Ordinary, + } + } + + fn fresh_expression(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::FreshExpression, + } + } + + pub(super) fn candidate_type(&self) -> LuaType { + self.ty.clone() + } + + pub(super) fn is_const_preserving(&self) -> bool { + self.view == InferenceCandidateView::ConstPreserving + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(in crate::semantic::generic) enum InferencePriority { + Normal, + Return, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(in crate::semantic::generic) enum InferenceVariance { + Covariant, + Contravariant, +} + +impl InferenceVariance { + pub(in crate::semantic::generic) fn flip(self) -> Self { + match self { + InferenceVariance::Covariant => InferenceVariance::Contravariant, + InferenceVariance::Contravariant => InferenceVariance::Covariant, + } + } +} + +#[derive(Debug, Clone)] +pub(in crate::semantic::generic) enum InferenceResult { + Type(LuaType), + MultiTypes(Vec), + VariadicParams(Vec<(String, Option)>), + VariadicBase(LuaType), +} + +pub(in crate::semantic::generic) fn is_literal_candidate(ty: &LuaType) -> bool { + match ty { + LuaType::StringConst(_) + | LuaType::DocStringConst(_) + | LuaType::IntegerConst(_) + | LuaType::DocIntegerConst(_) + | LuaType::FloatConst(_) + | LuaType::BooleanConst(_) + | LuaType::DocBooleanConst(_) + | LuaType::TableConst(_) => true, + LuaType::Union(union) => union.into_vec().iter().any(is_literal_candidate), + LuaType::Tuple(tuple) => tuple.get_types().iter().any(is_literal_candidate), + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => is_literal_candidate(base), + VariadicType::Multi(types) => types.iter().any(is_literal_candidate), + }, + _ => false, + } +} + +fn is_fresh_literal_expr(expr: &LuaExpr) -> bool { + match expr { + LuaExpr::LiteralExpr(_) | LuaExpr::TableExpr(_) => true, + LuaExpr::ParenExpr(paren) => paren + .get_expr() + .is_some_and(|expr| is_fresh_literal_expr(&expr)), + _ => false, + } +} + +pub(in crate::semantic::generic) fn get_str_tpl_infer_type(name: &str) -> LuaType { + match name { + "unknown" => LuaType::Unknown, + "never" => LuaType::Never, + "nil" | "void" => LuaType::Nil, + "any" => LuaType::Any, + "userdata" => LuaType::Userdata, + "thread" => LuaType::Thread, + "boolean" | "bool" => LuaType::Boolean, + "string" => LuaType::String, + "integer" | "int" => LuaType::Integer, + "number" => LuaType::Number, + "io" => LuaType::Io, + "self" => LuaType::SelfInfer, + "global" => LuaType::Global, + "function" => LuaType::Function, + _ => LuaType::Ref(LuaTypeDeclId::global(name)), + } +} + +pub(in crate::semantic::generic) fn escape_alias(db: &DbIndex, may_alias: &LuaType) -> LuaType { + if let LuaType::Ref(type_id) = may_alias + && let Some(type_decl) = db.get_type_index().get_type_decl(type_id) + && type_decl.is_alias() + && let Some(origin_type) = type_decl.get_alias_origin(db, None) + { + return origin_type.clone(); + } + + may_alias.clone() +} + +pub(super) fn is_tpl_at_top_level(db: &DbIndex, ty: &LuaType, tpl_id: GenericTplId) -> bool { + is_tpl_at_top_level_with_guard(db, ty, tpl_id, &mut HashSet::new()) +} + +fn is_tpl_at_top_level_with_guard( + db: &DbIndex, + ty: &LuaType, + tpl_id: GenericTplId, + visited_aliases: &mut HashSet, +) -> bool { + match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + LuaType::Union(union) => union.into_vec().iter().any(|member| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), + LuaType::MultiLineUnion(multi) => multi.get_unions().iter().any(|(member, _)| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), + LuaType::Generic(generic) => { + let type_decl_id = generic.get_base_type_id_ref(); + let Some(alias_param) = + get_transparent_alias_param_index(db, type_decl_id, visited_aliases) + else { + return false; + }; + + generic.get_params().get(alias_param).is_some_and(|param| { + is_tpl_at_top_level_with_guard(db, param, tpl_id, visited_aliases) + }) + } + _ => false, + } +} + +fn get_transparent_alias_param_index( + db: &DbIndex, + type_decl_id: &LuaTypeDeclId, + visited_aliases: &mut HashSet, +) -> Option { + if !visited_aliases.insert(type_decl_id.clone()) { + return None; + } + + let type_decl = db.get_type_index().get_type_decl(type_decl_id)?; + if !type_decl.is_alias() { + return None; + }; + let origin = type_decl.get_alias_ref()?; + + match origin { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + LuaType::Generic(generic) => { + get_transparent_alias_param_index(db, generic.get_base_type_id_ref(), visited_aliases) + .and_then(|alias_param| generic.get_params().get(alias_param)) + .and_then(|param| match param { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + _ => None, + }) + } + _ => None, + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/resolve.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/resolve.rs new file mode 100644 index 000000000..07135328e --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/resolve.rs @@ -0,0 +1,135 @@ +use crate::{DbIndex, GenericTpl, LuaType, TypeOps, instantiate_type_generic}; + +use crate::semantic::generic::{ + TypeMapper, is_primitive_or_literal_type, regularize_tpl_candidate_type, + widen_tpl_candidate_type, +}; + +use super::context::{InferenceInfo, InferenceRecord}; +use super::{InferenceCandidate, InferenceCandidateView, InferencePriority, InferenceResult}; + +pub(super) fn resolve_info( + db: &DbIndex, + tpl: &GenericTpl, + info: &InferenceInfo, + return_top_level: bool, + mapper: &TypeMapper, +) -> Option { + if let Some(fixed) = &info.fixed { + return Some(fixed.clone()); + } + + if !info.multi.is_empty() { + let primitive_constraint = tpl.get_constraint().is_some_and(|constraint| { + let constraint = instantiate_type_generic(db, constraint, mapper); + is_primitive_or_literal_type(&constraint) + }); + let const_preserving = info + .multi + .iter() + .any(|record| record.candidate.is_const_preserving()); + let preserve_root_literal_form = + primitive_constraint || const_preserving || return_top_level; + return Some(InferenceResult::MultiTypes( + info.multi + .iter() + .map(|record| { + resolve_candidate_type(db, &record.candidate, preserve_root_literal_form) + }) + .collect(), + )); + } + + if !info.covariant.is_empty() { + return Some(InferenceResult::Type(resolve_covariant_candidates( + db, + tpl, + info, + return_top_level, + mapper, + ))); + } + + if !info.contravariant.is_empty() { + return Some(InferenceResult::Type(combine_records( + db, + &info.contravariant, + info.priority.unwrap_or(InferencePriority::Normal), + true, + TypeOps::Intersect, + ))); + } + + None +} + +fn resolve_covariant_candidates( + db: &DbIndex, + tpl: &GenericTpl, + info: &InferenceInfo, + return_top_level: bool, + mapper: &TypeMapper, +) -> LuaType { + let primitive_constraint = tpl + .get_constraint() + .map(|constraint| { + let constraint = instantiate_type_generic(db, constraint, mapper); + is_primitive_or_literal_type(&constraint) + }) + .unwrap_or(false); + let const_preserving = info + .covariant + .iter() + .any(|record| record.candidate.is_const_preserving()); + let preserve_root_literal_form = + primitive_constraint || const_preserving || !info.top_level || return_top_level; + + combine_records( + db, + &info.covariant, + info.priority.unwrap_or(InferencePriority::Normal), + preserve_root_literal_form, + TypeOps::Union, + ) +} + +fn combine_records( + db: &DbIndex, + records: &[InferenceRecord], + max_priority: InferencePriority, + preserve_root_literal_form: bool, + op: TypeOps, +) -> LuaType { + let mut selected = records + .iter() + .filter(|record| record.priority == max_priority) + .map(|record| resolve_candidate_type(db, &record.candidate, preserve_root_literal_form)); + + let Some(first) = selected.next() else { + return LuaType::Unknown; + }; + + selected.fold(first, |acc, ty| op.apply(db, &acc, &ty)) +} + +fn resolve_candidate_type( + db: &DbIndex, + candidate: &InferenceCandidate, + preserve_root_literal_form: bool, +) -> LuaType { + if candidate.is_const_preserving() { + // std.ConstTpl 需要保留结构字面量, 例如 tuple/object/table const. + return candidate.candidate_type(); + } + + if preserve_root_literal_form { + return regularize_tpl_candidate_type(db, candidate.candidate_type()); + } + + match candidate.view { + InferenceCandidateView::FreshExpression | InferenceCandidateView::Ordinary => { + widen_tpl_candidate_type(db, candidate.candidate_type()) + } + _ => regularize_tpl_candidate_type(db, candidate.candidate_type()), + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs new file mode 100644 index 000000000..b826377c0 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs @@ -0,0 +1,90 @@ +use std::sync::Arc; + +use hashbrown::HashSet; +use smol_str::SmolStr; + +use super::{ + InferenceCandidate, InferenceContext, InferencePriority, InferenceResult, InferenceVariance, + context::inference_result_to_mapper_value, +}; +use crate::{ + CacheOptions, DbIndex, FileId, GenericTpl, GenericTplId, LuaInferCache, LuaType, LuaTypeDeclId, + TypeOps, + semantic::generic::{TypeMapperValue, get_mapped_value}, +}; + +#[test] +fn non_fixing_mapper_does_not_lock_later_candidates() { + let db = DbIndex::new(); + let mut cache = LuaInferCache::new(FileId::VIRTUAL, CacheOptions::default()); + let mut context = InferenceContext::new(&db, &mut cache, None); + let tpl_id = GenericTplId::Func(0); + let tpl = Arc::new(GenericTpl::new( + tpl_id, + SmolStr::new("T").into(), + None, + None, + )); + let return_type = LuaType::TplRef(tpl.clone()); + + context.prepare_inference_slots(HashSet::from([tpl_id])); + + // non_fixing_mapper 会读取当前推断结果, 但不应把泛型槽位标记为固定. + context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::String), + InferenceVariance::Covariant, + true, + InferencePriority::Normal, + ); + + let non_fixing = context.non_fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &non_fixing).and_then(|value| value.raw_type()), + Some(LuaType::String) + ); + + // 如果上面的 mapper 错误地固定了 T, 这里的新候选会被忽略; + // 正确行为是继续合并协变候选, 得到 string | integer. + context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::Integer), + InferenceVariance::Covariant, + true, + InferencePriority::Normal, + ); + let expected = TypeOps::Union.apply(&db, &LuaType::String, &LuaType::Integer); + let fixing = context.fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &fixing).and_then(|value| value.raw_type()), + Some(expected.clone()) + ); + + // fixing_mapper 才会真正锁定槽位; 锁定后再插入候选不应改变结果. + context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::Boolean), + InferenceVariance::Covariant, + true, + InferencePriority::Normal, + ); + let fixed = context.fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &fixed).and_then(|value| value.raw_type()), + Some(expected) + ); +} + +#[test] +fn variadic_params_mapper_normalizes_def_types() { + let type_id = LuaTypeDeclId::global("Alias"); + let value = inference_result_to_mapper_value(InferenceResult::VariadicParams(vec![( + "value".to_string(), + Some(LuaType::Def(type_id.clone())), + )])); + + assert_eq!( + value, + TypeMapperValue::Params(vec![("value".to_string(), Some(LuaType::Ref(type_id)))]) + ); +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs index de85aa424..0b11a8c68 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs @@ -7,7 +7,7 @@ use crate::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaTupleType, LuaType, LuaUnionType, VariadicType, }, - semantic::generic::type_substitutor::TypeSubstitutor, + semantic::generic::{TypeMapper, TypeMapperValue}, }; use super::instantiate_type_generic; @@ -90,13 +90,15 @@ fn complete_type_generic_args_inner( } let mut params = Vec::with_capacity(generic_params.len().max(provided_args.len())); - let mut substitutor = TypeSubstitutor::new(); + let mut prefix_sources = Vec::with_capacity(generic_params.len()); + let mut prefix_targets = Vec::with_capacity(generic_params.len()); let mut missing_required_count = 0; let mut cycled = false; for (idx, generic_param) in generic_params.iter().enumerate() { if let Some(provided_arg) = provided_args.get(idx) { let provided_arg = provided_arg.clone(); - substitutor.bind_type(GenericTplId::Type(idx as u32), provided_arg.clone()); + prefix_sources.push(GenericTplId::Type(idx as u32)); + prefix_targets.push(provided_arg.clone()); params.push(provided_arg); continue; } @@ -114,8 +116,17 @@ fn complete_type_generic_args_inner( } else { completed_type.ty }; - let instantiated = instantiate_type_generic(db, &default_type, &substitutor); - substitutor.bind_type(GenericTplId::Type(idx as u32), instantiated.clone()); + let mapper = TypeMapper::from_values( + prefix_sources.clone(), + prefix_targets + .iter() + .cloned() + .map(TypeMapperValue::type_value) + .collect(), + ); + let instantiated = instantiate_type_generic(db, &default_type, &mapper); + prefix_sources.push(GenericTplId::Type(idx as u32)); + prefix_targets.push(instantiated.clone()); params.push(instantiated); } else { missing_required_count += 1; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs new file mode 100644 index 000000000..316a19f5e --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs @@ -0,0 +1,115 @@ +use std::sync::Arc; + +use crate::{DbIndex, LuaType, LuaTypeDeclId}; + +use super::super::TypeMapper; + +const MAX_INSTANTIATION_DEPTH: usize = 128; +const MAX_ALIAS_STACK: usize = 32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(in crate::semantic::generic) enum UninferredTplPolicy { + /// 未推断模板按 `default -> constraint -> unknown` 推断成实际类型. + Fallback, + /// 没有默认值的未推断模板仍保留为 `TplRef`, 让后续调用点继续参与参数推导. + PreserveTplRef, +} + +#[derive(Debug)] +pub(in crate::semantic::generic) struct GenericInstantiateContext<'a> { + pub db: &'a DbIndex, + pub mapper: TypeMapper, + self_type: Option<&'a LuaType>, + alias_stack: Arc<[LuaTypeDeclId]>, +} + +#[derive(Debug, Clone, Copy)] +pub(in crate::semantic::generic) struct GenericInstantiateFrame { + policy: UninferredTplPolicy, + depth: usize, +} + +impl<'a> GenericInstantiateContext<'a> { + pub(super) fn new( + db: &'a DbIndex, + mapper: &TypeMapper, + self_type: Option<&'a LuaType>, + ) -> Self { + Self { + db, + mapper: mapper.clone(), + self_type, + alias_stack: Arc::from([]), + } + } + + pub(super) fn root_frame(&self) -> GenericInstantiateFrame { + GenericInstantiateFrame { + policy: UninferredTplPolicy::Fallback, + depth: 0, + } + } + + pub(super) fn self_type(&self) -> Option<&LuaType> { + self.self_type + } + + pub(super) fn with_mapper(&self, mapper: TypeMapper) -> GenericInstantiateContext<'a> { + GenericInstantiateContext { + db: self.db, + mapper, + self_type: self.self_type, + alias_stack: self.alias_stack.clone(), + } + } + + pub(super) fn enter_alias_stack( + &self, + alias_type_id: &LuaTypeDeclId, + ) -> Option> { + if self.alias_stack.len() >= MAX_ALIAS_STACK + || self.alias_stack.iter().any(|id| id == alias_type_id) + { + return None; + } + + let mut alias_stack = Vec::with_capacity(self.alias_stack.len() + 1); + alias_stack.extend(self.alias_stack.iter().cloned()); + alias_stack.push(alias_type_id.clone()); + Some(Arc::from(alias_stack)) + } + + pub(super) fn with_alias_stack( + &self, + alias_stack: Arc<[LuaTypeDeclId]>, + mapper: TypeMapper, + ) -> GenericInstantiateContext<'a> { + GenericInstantiateContext { + db: self.db, + mapper, + self_type: self.self_type, + alias_stack, + } + } +} + +impl GenericInstantiateFrame { + pub(super) fn with_policy(self, policy: UninferredTplPolicy) -> Self { + Self { policy, ..self } + } + + pub(super) fn should_preserve_tpl_ref(&self) -> bool { + self.policy == UninferredTplPolicy::PreserveTplRef + } + + pub(super) fn enter(self) -> Option { + if self.depth >= MAX_INSTANTIATION_DEPTH { + return None; + } + + Some(Self { + depth: self.depth + 1, + ..self + }) + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs index 2ae7fd3d7..60f188d3f 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs @@ -28,7 +28,9 @@ use crate::{ infer_node_semantic_decl, }; -use super::{TypeSubstitutor, instantiate_type_generic}; +use super::{ + TypeMapper, TypeMapperValue, instantiate_type_generic, instantiate_type_generic_with_self, +}; pub fn infer_call_func_generic( db: &DbIndex, @@ -37,7 +39,7 @@ pub fn infer_call_func_generic( call_expr: LuaCallExpr, ) -> Result { let file_id = cache.get_file_id().clone(); - let (generic_tpls, contain_self) = collect_func_tpl_ids(func); + let (inference_slots, mapper_tpls, contain_self) = collect_func_generic_info(func); let origin_params = func.get_params(); let mut func_params: Vec<_> = origin_params @@ -50,11 +52,11 @@ pub fn infer_call_func_generic( .ok_or(InferFailReason::None)? .get_args() .collect::>(); - let mut substitutor = TypeSubstitutor::new(); + let mapper; { let mut context = InferenceContext::new(db, cache, Some(call_expr.clone())); - if !generic_tpls.is_empty() { - context.prepare_inference_slots(generic_tpls); + if !inference_slots.is_empty() { + context.prepare_inference_slots(inference_slots); if let Some(type_list) = call_expr.get_call_generic_type_list() { // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 @@ -72,15 +74,16 @@ pub fn infer_call_func_generic( } } - let func_generic_tpls = func_generic_tpls(func); - context.bridge_to_substitutor(&mut substitutor, func_generic_tpls.iter(), func.get_ret()); + mapper = context.fixing_mapper(mapper_tpls.iter(), func.get_ret()); } - if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { - substitutor.add_self_type(self_type); - } + let self_type = contain_self + .then(|| infer_self_type(db, cache, &call_expr)) + .flatten(); let func_ty = LuaType::DocFunction(func.clone().into()); - if let LuaType::DocFunction(f) = instantiate_type_generic(db, &func_ty, &substitutor) { + if let LuaType::DocFunction(f) = + instantiate_type_generic_with_self(db, &func_ty, &mapper, self_type.as_ref()) + { Ok(f.deref().clone()) } else { Ok(func.clone()) @@ -232,12 +235,11 @@ fn instantiate_callable_from_arg_types( return None; } - let mut callable_tpls = HashSet::new(); - callable.visit_nested_types(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { - callable_tpls.insert(generic_tpl.get_tpl_id()); - } - }); + let (_, callable_mapper_tpls, _) = collect_func_generic_info(callable); + let callable_tpls = callable_mapper_tpls + .iter() + .map(|tpl| tpl.get_tpl_id()) + .collect::>(); if callable_tpls.is_empty() { return Some(callable.clone()); } @@ -247,10 +249,7 @@ fn instantiate_callable_from_arg_types( .iter() .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) .collect::>(); - let callable_generic_tpls = callable_generic_tpls(callable); - let mut callable_substitutor = TypeSubstitutor::new(); - callable_substitutor.prepare_inference_slots(callable_tpls.clone()); - { + let (non_fixing_mapper, mapper) = { let mut callable_context = InferenceContext::new(context.db, context.cache, context.call_expr.clone()); callable_context.prepare_inference_slots(callable_tpls.clone()); @@ -266,16 +265,25 @@ fn instantiate_callable_from_arg_types( { return None; } - callable_context.bridge_resolved_to_substitutor( - &mut callable_substitutor, - callable_generic_tpls.iter(), - callable.get_ret(), - ); - } + let non_fixing_mapper = + callable_context.non_fixing_mapper(callable_mapper_tpls.iter(), callable.get_ret()); + let mapper = + callable_context.fixing_mapper(callable_mapper_tpls.iter(), callable.get_ret()); + (non_fixing_mapper, mapper) + }; - let callable_ty = LuaType::DocFunction(callable.clone()); + let return_only_callable = LuaType::DocFunction( + LuaFunctionType::new( + callable.get_async_state(), + callable.is_colon_define(), + false, + Vec::new(), + callable.get_ret().clone(), + ) + .into(), + ); let instantiated = - match instantiate_type_generic(context.db, &callable_ty, &callable_substitutor) { + match instantiate_type_generic(context.db, &return_only_callable, &non_fixing_mapper) { LuaType::DocFunction(func) => func, _ => callable.clone(), }; @@ -289,7 +297,11 @@ fn instantiate_callable_from_arg_types( } }); if tpl_ids.is_empty() { - return Some(instantiated); + let callable_ty = LuaType::DocFunction(callable.clone()); + return match instantiate_type_generic(context.db, &callable_ty, &mapper) { + LuaType::DocFunction(func) => Some(func), + _ => Some(instantiated), + }; } tpl_ids }; @@ -304,11 +316,12 @@ fn instantiate_callable_from_arg_types( return None; } + let mut mapper = mapper; for tpl_id in callback_return_tpls { - callable_substitutor.bind_type(tpl_id, LuaType::Unknown); + mapper = TypeMapper::prepend(tpl_id, LuaType::Unknown, Some(mapper)); } let callable_ty = LuaType::DocFunction(callable.clone()); - match instantiate_type_generic(context.db, &callable_ty, &callable_substitutor) { + match instantiate_type_generic(context.db, &callable_ty, &mapper) { LuaType::DocFunction(func) => Some(func), _ => None, } @@ -366,43 +379,31 @@ fn collect_callback_return_tpls( callback_return_tpls } -fn collect_func_tpl_ids(func: &LuaFunctionType) -> (HashSet, bool) { - let mut generic_tpls = HashSet::new(); +fn collect_func_generic_info( + func: &LuaFunctionType, +) -> (HashSet, Vec>, bool) { + let mut inference_slots = HashSet::new(); + let mut mapper_tpls = Vec::new(); let mut contain_self = false; - - func.visit_nested_types(&mut |ty| match ty { - LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - collect_func_tpl_with_fallback_deps(generic_tpl, &mut generic_tpls); - } - LuaType::StrTplRef(str_tpl) => { - generic_tpls.insert(str_tpl.get_tpl_id()); - } - LuaType::SelfInfer => contain_self = true, - _ => {} - }); - - (generic_tpls, contain_self) -} - -fn func_generic_tpls(func: &LuaFunctionType) -> Vec> { - let mut generic_tpls = Vec::new(); func.visit_nested_types(&mut |ty| match ty { LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + collect_func_tpl_with_fallback_deps(generic_tpl, &mut inference_slots); if generic_tpl.get_tpl_id().is_func() - && !generic_tpls + && !mapper_tpls .iter() .any(|it: &Arc| it.get_tpl_id() == generic_tpl.get_tpl_id()) { - generic_tpls.push(generic_tpl.clone()); + mapper_tpls.push(generic_tpl.clone()); } } + LuaType::StrTplRef(str_tpl) => { + inference_slots.insert(str_tpl.get_tpl_id()); + } + LuaType::SelfInfer => contain_self = true, _ => {} }); - generic_tpls -} -fn callable_generic_tpls(callable: &LuaFunctionType) -> Vec> { - func_generic_tpls(callable) + (inference_slots, mapper_tpls, contain_self) } fn collect_func_tpl_with_fallback_deps( @@ -636,11 +637,21 @@ pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { LuaType::Def(id) | LuaType::Ref(id) => { if let Some(generic) = db.get_type_index().get_generic_params(id) { let mut params = Vec::with_capacity(generic.len()); - let mut substitutor = TypeSubstitutor::new(); + let mut prefix_sources = Vec::with_capacity(generic.len()); + let mut prefix_targets = Vec::with_capacity(generic.len()); for (i, generic_param) in generic.iter().enumerate() { let tpl_id = GenericTplId::Type(i as u32); - let param = build_self_generic_arg(db, generic_param, &substitutor); - substitutor.bind_type(tpl_id, param.clone()); + let mapper = TypeMapper::from_values( + prefix_sources.clone(), + prefix_targets + .iter() + .cloned() + .map(TypeMapperValue::type_value) + .collect(), + ); + let param = build_self_generic_arg(db, generic_param, &mapper); + prefix_sources.push(tpl_id); + prefix_targets.push(param.clone()); params.push(param); } let generic = LuaGenericType::new(id.clone(), params); @@ -655,7 +666,7 @@ pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { fn build_self_generic_arg( db: &DbIndex, generic_param: &GenericParam, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, ) -> LuaType { let Some(arg) = generic_param .default_type @@ -665,7 +676,7 @@ fn build_self_generic_arg( return LuaType::Unknown; }; - instantiate_type_generic(db, arg, substitutor) + instantiate_type_generic(db, arg, mapper) } pub fn infer_self_type( diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index 32a0d1f7c..3da8f3231 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -7,10 +7,9 @@ use crate::{ semantic::{member::find_members_with_key, type_check::check_type_compact_with_level}, }; +use super::{GenericInstantiateContext, GenericInstantiateFrame}; use super::{get_default_constructor, instantiate_type_generic_inner}; -use crate::semantic::generic::type_substitutor::{ - GenericInstantiateContext, GenericInstantiateFrame, TplBinding, -}; +use crate::semantic::generic::{TypeMapper, get_mapped_value}; #[derive(Debug, Clone, Copy)] enum InferVariance { @@ -149,9 +148,9 @@ fn instantiate_conditional_residual( let instantiate_branch = |branch: &LuaType| { if branch.any_type(|ty| match ty { LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { - context.substitutor.get(tpl.get_tpl_id()).is_some() + get_mapped_value(tpl.get_tpl_id(), &context.mapper).is_some() } - LuaType::SelfInfer => context.substitutor.get_self_type().is_some(), + LuaType::SelfInfer => context.self_type().is_some(), _ => false, }) { instantiate_type_generic_inner(context, frame, branch) @@ -191,7 +190,7 @@ fn instantiate_distributed_conditional( } _ => return None, }; - let raw_checked_type = context.substitutor.get_raw_type(tpl_id)?; + let raw_checked_type = get_mapped_value(tpl_id, &context.mapper)?.raw_type()?; if raw_checked_type.is_never() { return Some(LuaType::Never); @@ -208,9 +207,8 @@ fn instantiate_distributed_conditional( }; let mut result = LuaType::Never; for member in members { - let mut member_substitutor = context.substitutor.clone(); - member_substitutor.bind(tpl_id, TplBinding::ReplaceConstType(member)); - let member_context = context.with_substitutor(&member_substitutor); + let member_mapper = TypeMapper::prepend(tpl_id, member, Some(context.mapper.clone())); + let member_context = context.with_mapper(member_mapper); let member_result = instantiate_conditional_once(&member_context, frame, conditional); result = TypeOps::Union.apply(context.db, &result, &member_result); } @@ -228,11 +226,11 @@ fn instantiate_true_branch( return instantiate_type_generic_inner(context, frame, conditional.get_true_type()); } - let mut true_substitutor = context.substitutor.clone(); + let mut true_mapper = context.mapper.clone(); for (tpl_id, ty) in infer_assignments { - true_substitutor.bind(tpl_id, TplBinding::ConditionalInferType(ty)); + true_mapper = TypeMapper::prepend(tpl_id, ty, Some(true_mapper)); } - let true_context = context.with_substitutor(&true_substitutor); + let true_context = context.with_mapper(true_mapper); instantiate_type_generic_inner(&true_context, frame, conditional.get_true_type()) } @@ -752,8 +750,10 @@ fn instantiate_conditional_operand( let mut result = instantiate_type_generic_inner(context, frame, operand); if let LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) = operand { let tpl_id = tpl_ref.get_tpl_id(); - if let Some(raw) = context.substitutor.get_raw_type(tpl_id) { - result = raw.clone(); + if let Some(raw) = + get_mapped_value(tpl_id, &context.mapper).and_then(|value| value.raw_type()) + { + result = raw; } else if checked && result.is_never() { result = LuaType::Never; } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs index 21a9ad912..bbf9f25c4 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs @@ -10,7 +10,7 @@ use super::{ GenericInstantiateContext, GenericInstantiateFrame, instantiate_special_generic, instantiate_type_generic_inner, key_type_to_member_key, }; -use crate::semantic::generic::type_substitutor::TplBinding; +use crate::semantic::generic::TypeMapper; pub(super) fn instantiate_mapped_type( context: &GenericInstantiateContext, @@ -41,15 +41,14 @@ pub(super) fn instantiate_mapped_type( let mut field_indices: HashMap = HashMap::with_capacity(key_count); let mut fields: Vec<(LuaMemberKey, LuaType)> = Vec::with_capacity(key_count); let mut index_access: Vec<(LuaType, LuaType)> = Vec::with_capacity(key_count); - let mut local_substitutor = context.substitutor.clone(); - for key_ty in key_domain.keys { if !visited.insert(key_ty.clone()) { continue; } - local_substitutor.bind(mapped.param.0, TplBinding::ReplaceConstType(key_ty.clone())); - let local_context = context.with_substitutor(&local_substitutor); + let local_mapper = + TypeMapper::prepend(mapped.param.0, key_ty.clone(), Some(context.mapper.clone())); + let local_context = context.with_mapper(local_mapper); let mut value_ty = instantiate_type_generic_inner(&local_context, frame, &mapped.value); if mapped.is_optional { value_ty = TypeOps::Union.apply(context.db, &value_ty, &LuaType::Nil); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 80f10166f..04a0d0975 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -10,10 +10,8 @@ use crate::{ use hashbrown::HashMap; use std::{ops::Deref, vec}; -use super::{ - GenericInstantiateContext, GenericInstantiateFrame, SubstitutorValue, TypeSubstitutor, - instantiate_type_generic_inner, -}; +use super::{GenericInstantiateContext, GenericInstantiateFrame, instantiate_type_generic_inner}; +use crate::semantic::generic::get_mapped_value; pub(super) fn instantiate_alias_call( context: &GenericInstantiateContext, @@ -85,7 +83,7 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } - let key = resolve_literal_operand(operand_exprs.get(1), context.substitutor) + let key = resolve_literal_operand(operand_exprs.get(1), context) .unwrap_or_else(|| operands[1].clone()); instantiate_rawget_call(context.db, &operands[0], &key) @@ -95,7 +93,7 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } - let key = resolve_literal_operand(operand_exprs.get(1), context.substitutor) + let key = resolve_literal_operand(operand_exprs.get(1), context) .unwrap_or_else(|| operands[1].clone()); instantiate_index_call(context.db, &operands[0], &key) @@ -158,11 +156,12 @@ fn instantiate_merge_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { fn resolve_literal_operand( operand: Option<&LuaType>, - substitutor: &TypeSubstitutor, + context: &GenericInstantiateContext, ) -> Option { match operand { Some(LuaType::TplRef(tpl_ref)) | Some(LuaType::ConstTplRef(tpl_ref)) => { - substitutor.get_raw_type(tpl_ref.get_tpl_id()).cloned() + get_mapped_value(tpl_ref.get_tpl_id(), &context.mapper) + .and_then(|value| value.raw_type()) } _ => None, } @@ -254,28 +253,10 @@ fn resolve_unpack_operands( return instantiate_type_generic_inner(context, frame, operand); } let raw = match operand { - LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => context - .substitutor - .get(tpl_ref.get_tpl_id()) - .and_then(|value| match value { - SubstitutorValue::None => None, - SubstitutorValue::Type { value, .. } => Some(value.raw().clone()), - SubstitutorValue::MultiTypes { values, .. } => Some(LuaType::Variadic( - VariadicType::Multi( - values.iter().map(|value| value.raw().clone()).collect(), - ) - .into(), - )), - SubstitutorValue::Params(params) => Some( - params - .first() - .unwrap_or(&(String::new(), None)) - .1 - .clone() - .unwrap_or(LuaType::Unknown), - ), - SubstitutorValue::MultiBase(base) => Some(base.clone()), - }), + LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => { + get_mapped_value(tpl_ref.get_tpl_id(), &context.mapper) + .and_then(|value| value.raw_type()) + } _ => None, }; raw.unwrap_or_else(|| instantiate_type_generic_inner(context, frame, operand)) diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index aef69e547..b343f0441 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -1,4 +1,5 @@ mod complete_generic_args; +mod context; mod infer_call_func_generic; mod inference_widening; mod instantiate_conditional_generic; @@ -17,13 +18,11 @@ use crate::{ }, }; -use super::type_substitutor::{ - GenericInstantiateContext, GenericInstantiateFrame, SubstitutorValue, TypeSubstitutor, - UninferredTplPolicy, -}; +use super::{TypeMapper, TypeMapperValue, get_mapped_value}; pub use complete_generic_args::{ GenericArgumentCompletion, complete_type_generic_args, complete_type_generic_args_in_type, }; +use context::{GenericInstantiateContext, GenericInstantiateFrame, UninferredTplPolicy}; pub use infer_call_func_generic::{build_self_type, infer_call_func_generic, infer_self_type}; pub(in crate::semantic::generic) use inference_widening::{ is_primitive_or_literal_type, regularize_tpl_candidate_type, widen_tpl_candidate_type, @@ -31,12 +30,17 @@ pub(in crate::semantic::generic) use inference_widening::{ use instantiate_mapped_type::instantiate_mapped_type as instantiate_mapped_type_inner; pub use instantiate_special_generic::get_keyof_members; -pub fn instantiate_type_generic( +pub fn instantiate_type_generic(db: &DbIndex, ty: &LuaType, mapper: &TypeMapper) -> LuaType { + instantiate_type_generic_with_self(db, ty, mapper, None) +} + +pub fn instantiate_type_generic_with_self( db: &DbIndex, ty: &LuaType, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, + self_type: Option<&LuaType>, ) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); + let context = GenericInstantiateContext::new(db, mapper, self_type); let frame = context.root_frame(); match ty { LuaType::DocFunction(doc_func) => instantiate_doc_function(&context, frame, doc_func), @@ -49,25 +53,60 @@ pub(super) fn instantiate_type_generic_inner( frame: GenericInstantiateFrame, ty: &LuaType, ) -> LuaType { + if is_simple_instantiate_leaf(ty) { + return ty.clone(); + } + let Some(frame) = frame.enter() else { return ty.clone(); }; match ty { - LuaType::Array(array_type) => instantiate_array(context, frame, array_type.get_base()), - LuaType::Tuple(tuple) => instantiate_tuple(context, frame, tuple), - LuaType::DocFunction(doc_func) => instantiate_doc_function( - context, - frame.with_policy(UninferredTplPolicy::PreserveTplRef), - doc_func, - ), - LuaType::Object(object) => instantiate_object(context, frame, object), - LuaType::Union(union) => instantiate_union(context, frame, union), + LuaType::Array(array_type) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + instantiate_array(context, frame, array_type.get_base()) + } + LuaType::Tuple(tuple) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + instantiate_tuple(context, frame, tuple) + } + LuaType::DocFunction(doc_func) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + instantiate_doc_function( + context, + frame.with_policy(UninferredTplPolicy::PreserveTplRef), + doc_func, + ) + } + LuaType::Object(object) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + instantiate_object(context, frame, object) + } + LuaType::Union(union) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + instantiate_union(context, frame, union) + } LuaType::Intersection(intersection) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } instantiate_intersection(context, frame, intersection) } LuaType::Generic(generic) => instantiate_generic(context, frame, generic), LuaType::TableGeneric(table_params) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } instantiate_table_generic(context, frame, table_params) } LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context, frame), @@ -78,13 +117,16 @@ pub(super) fn instantiate_type_generic_inner( } LuaType::Variadic(variadic) => instantiate_variadic_type(context, frame, variadic), LuaType::SelfInfer => { - if let Some(typ) = context.substitutor.get_self_type() { + if let Some(typ) = context.self_type() { typ.clone() } else { LuaType::SelfInfer } } LuaType::TypeGuard(guard) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } let inner = instantiate_type_generic_inner(context, frame, guard.deref()); LuaType::TypeGuard(inner.into()) } @@ -96,6 +138,32 @@ pub(super) fn instantiate_type_generic_inner( } } +fn requires_instantiation_walk(ty: &LuaType) -> bool { + match ty { + LuaType::TplRef(_) + | LuaType::StrTplRef(_) + | LuaType::ConstTplRef(_) + | LuaType::SelfInfer + | LuaType::Generic(_) + | LuaType::Signature(_) + | LuaType::Call(_) + | LuaType::Conditional(_) + | LuaType::Mapped(_) => true, + LuaType::Array(array_type) => requires_instantiation_walk(array_type.get_base()), + LuaType::Tuple(tuple) => tuple.any_type(requires_instantiation_walk), + LuaType::DocFunction(doc_func) => doc_func.any_type(requires_instantiation_walk), + LuaType::Object(object) => object.any_type(requires_instantiation_walk), + LuaType::Union(union) => union.any_type(requires_instantiation_walk), + LuaType::Intersection(intersection) => intersection.any_type(requires_instantiation_walk), + LuaType::TableGeneric(table_params) => table_params.iter().any(requires_instantiation_walk), + LuaType::Variadic(variadic) => variadic.any_type(requires_instantiation_walk), + LuaType::TypeGuard(guard) => requires_instantiation_walk(guard.deref()), + LuaType::MultiLineUnion(inner) => inner.any_type(requires_instantiation_walk), + LuaType::DocAttribute(attr) => attr.any_type(requires_instantiation_walk), + _ => false, + } +} + fn instantiate_types<'a, I>( context: &GenericInstantiateContext, frame: GenericInstantiateFrame, @@ -149,24 +217,20 @@ fn instantiate_tuple( match inner.deref() { VariadicType::Base(base) => { if let LuaType::TplRef(tpl) = base { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => new_types + TypeMapperValue::None => new_types .push(instantiate_uninferred_tpl_fallback(tpl, context, frame)), - SubstitutorValue::Params(params) => { + TypeMapperValue::Params(params) => { for (_, ty) in params { - new_types.push(ty.clone().unwrap_or(LuaType::Unknown)); + new_types.push(ty.unwrap_or(LuaType::Unknown)); } } - SubstitutorValue::MultiTypes { values, .. } => { - new_types.extend( - values.iter().map(|value| value.resolved().clone()), - ); - } - SubstitutorValue::Type { value, .. } => { - new_types.push(value.resolved().clone()) + TypeMapperValue::MultiTypes(values) => { + new_types.extend(values); } - SubstitutorValue::MultiBase(base) => new_types.push(base.clone()), + TypeMapperValue::Type(value) => new_types.push(value), + TypeMapperValue::MultiBase(base) => new_types.push(base), } } else { new_types.push(LuaType::Variadic(inner.clone())); @@ -190,6 +254,10 @@ fn instantiate_doc_function( frame: GenericInstantiateFrame, doc_func: &LuaFunctionType, ) -> LuaType { + if !doc_func.any_type(requires_instantiation_walk) { + return LuaType::DocFunction(doc_func.clone().into()); + } + let tpl_func_params = doc_func.get_params(); let tpl_ret = doc_func.get_ret(); let async_state = doc_func.get_async_state(); @@ -207,15 +275,14 @@ fn instantiate_doc_function( LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Base(base) => match base { LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => { + TypeMapperValue::None => { let ty = instantiate_uninferred_tpl_fallback(tpl, context, frame); new_params.push((origin_param.0.clone(), Some(ty))); } - SubstitutorValue::Type { value, .. } => { - let resolved_type = value.resolved().clone(); + TypeMapperValue::Type(resolved_type) => { // 如果参数是 `...: T...` if origin_param.0 == "..." { // 类型是 tuple, 那么我们将展开 tuple @@ -241,16 +308,15 @@ fn instantiate_doc_function( )), )); } - SubstitutorValue::Params(params) => { - for param in params.iter() { - new_params.push(param.clone()); + TypeMapperValue::Params(params) => { + for param in params { + new_params.push(param); } } - SubstitutorValue::MultiTypes { values, .. } => { - for (i, value) in values.iter().enumerate() { + TypeMapperValue::MultiTypes(values) => { + for (i, value) in values.into_iter().enumerate() { let param_name = format!("var{}", i); - new_params - .push((param_name, Some(value.resolved().clone()))); + new_params.push((param_name, Some(value))); } } _ => { @@ -335,6 +401,10 @@ fn instantiate_object( frame: GenericInstantiateFrame, object: &LuaObjectType, ) -> LuaType { + if !object.any_type(requires_instantiation_walk) { + return LuaType::Object(object.clone().into()); + } + let new_fields = object .get_fields() .iter() @@ -356,6 +426,10 @@ fn instantiate_union( frame: GenericInstantiateFrame, union: &LuaUnionType, ) -> LuaType { + if !union.any_type(requires_instantiation_walk) { + return LuaType::Union(union.clone().into()); + } + LuaType::from_vec(instantiate_types(context, frame, union.into_vec().iter())) } @@ -364,6 +438,10 @@ fn instantiate_intersection( frame: GenericInstantiateFrame, intersection: &LuaIntersectionType, ) -> LuaType { + if !intersection.any_type(requires_instantiation_walk) { + return LuaType::Intersection(intersection.clone().into()); + } + LuaType::Intersection( LuaIntersectionType::new(instantiate_types( context, @@ -389,16 +467,15 @@ fn instantiate_generic( return LuaType::Unknown; }; - if !context.substitutor.check_recursion(&type_decl_id) - && let Some(type_decl) = context.db.get_type_index().get_type_decl(&type_decl_id) + if let Some(type_decl) = context.db.get_type_index().get_type_decl(&type_decl_id) && type_decl.is_alias() { - let Some(alias_context) = context.enter_alias(&type_decl_id) else { + let Some(alias_stack) = context.enter_alias_stack(&type_decl_id) else { return LuaType::Generic(LuaGenericType::new(type_decl_id, new_params).into()); }; - let new_substitutor = - TypeSubstitutor::from_alias(context.db, new_params.clone(), type_decl_id.clone()); - let alias_context = alias_context.with_substitutor(&new_substitutor); + let alias_mapper = TypeMapper::from_alias(context.db, new_params.clone(), &type_decl_id); + let alias_mapper = TypeMapper::merge(Some(alias_mapper), context.mapper.clone()); + let alias_context = context.with_alias_stack(alias_stack, alias_mapper); if let Some(origin) = type_decl.get_alias_ref() { return instantiate_type_generic_inner(&alias_context, frame, origin); } @@ -412,6 +489,10 @@ fn instantiate_table_generic( frame: GenericInstantiateFrame, table_params: &[LuaType], ) -> LuaType { + if !table_params.iter().any(requires_instantiation_walk) { + return LuaType::TableGeneric(table_params.to_vec().into()); + } + LuaType::TableGeneric(instantiate_types(context, frame, table_params.iter()).into()) } @@ -442,34 +523,24 @@ fn instantiate_tpl_ref( context: &GenericInstantiateContext, frame: GenericInstantiateFrame, ) -> LuaType { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => { + TypeMapperValue::None => { return instantiate_uninferred_tpl_fallback(tpl, context, frame); } - SubstitutorValue::Type { value, .. } => { - return value.resolved().clone(); + TypeMapperValue::Type(value) => { + return value; } - SubstitutorValue::MultiTypes { values, .. } => { - return LuaType::Variadic( - VariadicType::Multi( - values - .iter() - .map(|value| value.resolved().clone()) - .collect(), - ) - .into(), - ); + TypeMapperValue::MultiTypes(values) => { + return LuaType::Variadic(VariadicType::Multi(values).into()); } - SubstitutorValue::Params(params) => { + TypeMapperValue::Params(params) => { return params .first() - .unwrap_or(&(String::new(), None)) - .1 - .clone() + .and_then(|(_, ty)| ty.clone()) .unwrap_or(LuaType::Unknown); } - SubstitutorValue::MultiBase(base) => return base.clone(), + TypeMapperValue::MultiBase(base) => return base, } } @@ -481,34 +552,24 @@ fn instantiate_const_tpl_ref( context: &GenericInstantiateContext, frame: GenericInstantiateFrame, ) -> LuaType { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => { + TypeMapperValue::None => { return instantiate_uninferred_tpl_fallback(tpl, context, frame); } - SubstitutorValue::Type { value, .. } => { - return value.resolved().clone(); + TypeMapperValue::Type(value) => { + return value; } - SubstitutorValue::MultiTypes { values, .. } => { - return LuaType::Variadic( - VariadicType::Multi( - values - .iter() - .map(|value| value.resolved().clone()) - .collect(), - ) - .into(), - ); + TypeMapperValue::MultiTypes(values) => { + return LuaType::Variadic(VariadicType::Multi(values).into()); } - SubstitutorValue::Params(params) => { + TypeMapperValue::Params(params) => { return params .first() - .unwrap_or(&(String::new(), None)) - .1 - .clone() + .and_then(|(_, ty)| ty.clone()) .unwrap_or(LuaType::Unknown); } - SubstitutorValue::MultiBase(base) => return base.clone(), + TypeMapperValue::MultiBase(base) => return base, } } @@ -549,12 +610,16 @@ fn instantiate_variadic_type( frame: GenericInstantiateFrame, variadic: &VariadicType, ) -> LuaType { + if !variadic.any_type(requires_instantiation_walk) { + return LuaType::Variadic(variadic.clone().into()); + } + match variadic { VariadicType::Base(base) => match base { LuaType::TplRef(tpl) => { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => { + TypeMapperValue::None => { let fallback = instantiate_uninferred_tpl_fallback(tpl, context, frame); return match fallback { LuaType::Variadic(_) | LuaType::Never => fallback, @@ -562,8 +627,7 @@ fn instantiate_variadic_type( _ => LuaType::Variadic(VariadicType::Base(fallback).into()), }; } - SubstitutorValue::Type { value, .. } => { - let resolved_type = value.resolved().clone(); + TypeMapperValue::Type(resolved_type) => { if matches!( resolved_type, LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never @@ -572,26 +636,18 @@ fn instantiate_variadic_type( } return LuaType::Variadic(VariadicType::Base(resolved_type).into()); } - SubstitutorValue::MultiTypes { values, .. } => { - return LuaType::Variadic( - VariadicType::Multi( - values - .iter() - .map(|value| value.resolved().clone()) - .collect(), - ) - .into(), - ); + TypeMapperValue::MultiTypes(values) => { + return LuaType::Variadic(VariadicType::Multi(values).into()); } - SubstitutorValue::Params(params) => { + TypeMapperValue::Params(params) => { let types = params - .iter() - .filter_map(|(_, ty)| ty.clone()) + .into_iter() + .filter_map(|(_, ty)| ty) .collect::>(); return LuaType::Variadic(VariadicType::Multi(types).into()); } - SubstitutorValue::MultiBase(base) => { - return LuaType::Variadic(VariadicType::Base(base.clone()).into()); + TypeMapperValue::MultiBase(base) => { + return LuaType::Variadic(VariadicType::Base(base).into()); } } } else { @@ -648,3 +704,36 @@ pub(super) fn get_default_constructor(db: &DbIndex, decl_id: &LuaTypeDeclId) -> let operator = db.get_operator_index().get_operator(id)?; Some(operator.get_operator_func(db)) } + +fn is_simple_instantiate_leaf(ty: &LuaType) -> bool { + matches!( + ty, + LuaType::Unknown + | LuaType::Any + | LuaType::Nil + | LuaType::Table + | LuaType::Userdata + | LuaType::Function + | LuaType::Thread + | LuaType::Boolean + | LuaType::String + | LuaType::Integer + | LuaType::Number + | LuaType::Io + | LuaType::Global + | LuaType::Never + | LuaType::BooleanConst(_) + | LuaType::StringConst(_) + | LuaType::IntegerConst(_) + | LuaType::FloatConst(_) + | LuaType::TableConst(_) + | LuaType::Ref(_) + | LuaType::Def(_) + | LuaType::DocStringConst(_) + | LuaType::DocIntegerConst(_) + | LuaType::DocBooleanConst(_) + | LuaType::Namespace(_) + | LuaType::Language(_) + | LuaType::ModuleRef(_) + ) +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index 5cd76d080..a617ae2f0 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -2,7 +2,7 @@ mod call_constraint; mod inference; mod instantiate_type; mod test; -mod type_substitutor; +mod type_mapper; use std::sync::Arc; @@ -19,7 +19,8 @@ pub(in crate::semantic::generic) use inference::{ }; pub use instantiate_type::*; use rowan::NodeOrToken; -pub use type_substitutor::TypeSubstitutor; +pub(in crate::semantic::generic) use type_mapper::get_mapped_value; +pub use type_mapper::{TypeMapper, TypeMapperValue}; use crate::DbIndex; use crate::GenericTpl; @@ -65,17 +66,12 @@ pub fn instantiate_doc_function_by_arg_types( InferencePriority::Normal, )?; - let mut substitutor = TypeSubstitutor::new(); let generic_tpls = collect_doc_function_generic_tpls(doc_function); - context.bridge_to_substitutor( - &mut substitutor, - generic_tpls.iter(), - doc_function.get_ret(), - ); + let mapper = context.fixing_mapper(generic_tpls.iter(), doc_function.get_ret()); let doc_function_ty = LuaType::DocFunction(doc_function.clone()); Ok( - match instantiate_type_generic(db, &doc_function_ty, &substitutor) { + match instantiate_type_generic(db, &doc_function_ty, &mapper) { LuaType::DocFunction(func) => func, _ => doc_function.clone(), }, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/test.rs b/crates/emmylua_code_analysis/src/semantic/generic/test.rs index ebc0fc52d..0581483fc 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/test.rs @@ -1,7 +1,16 @@ #[cfg(test)] mod test { + use hashbrown::HashMap; + use std::sync::Arc; + use super::super::instantiate_type::{regularize_tpl_candidate_type, widen_tpl_candidate_type}; - use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; + use crate::{ + AsyncState, DbIndex, DiagnosticCode, GenericTpl, GenericTplId, LuaArrayType, + LuaFunctionType, LuaIntersectionType, LuaMemberKey, LuaObjectType, LuaTupleStatus, + LuaTupleType, LuaType, LuaUnionType, TypeMapper, TypeMapperValue, VariadicType, + VirtualWorkspace, + }; + use smol_str::SmolStr; #[test] fn test_variadic_func() { @@ -300,6 +309,372 @@ result = { )); } + #[test] + fn test_inference_mapper_fallback_and_explicit_precedence() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T = string + ---@return T + local function defaulted() + end + + ---@generic T: integer + ---@return T + local function constrained() + end + + ---@generic T + ---@param value T + ---@return T + local function explicit(value) + end + + default_result = defaulted() + constraint_result = constrained() + explicit_result = explicit--[[@]](1) + + "#, + ); + + assert_eq!(ws.expr_ty("default_result"), ws.ty("string")); + assert_eq!(ws.expr_ty("constraint_result"), ws.ty("integer")); + assert_eq!(ws.expr_ty("explicit_result"), ws.ty("string")); + } + + #[test] + fn test_mapper_reducer_reuses_alias_mapped_and_function_shapes() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias MapperBox { value: T, list: T[] } + ---@alias Copy { [K in keyof T]: T[K]; } + + ---@generic T + ---@param value T + ---@return MapperBox + local function box(value) + end + + ---@generic T + ---@param value T + ---@return Copy + local function copy(value) + end + + ---@generic T + ---@param value T + ---@return fun(next: T): T + local function make_id(value) + end + + box_result = box("name") + box_value = box_result.value + box_list_item = box_result.list[1] + + copied = copy({ name = "a", count = 1 }) + copied_name = copied.name + copied_count = copied.count + + made = make_id(1) + made_ret = made(2) + "#, + ); + + assert_eq!(ws.expr_ty("box_value"), ws.ty("string")); + assert_eq!(ws.expr_ty("box_list_item"), ws.ty("string?")); + assert_eq!(ws.expr_ty("copied_name"), ws.ty("string")); + assert_eq!(ws.expr_ty("copied_count"), ws.ty("integer")); + assert_eq!(ws.expr_ty("made_ret"), ws.ty("integer")); + } + + #[test] + fn test_structural_instantiate_fast_path_preserves_plain_shapes() { + let db = DbIndex::new(); + let empty_mapper = TypeMapper::empty(); + + let plain_array = LuaType::Array(LuaArrayType::from_base_type(LuaType::Number).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_array, + &empty_mapper + ), + plain_array + ); + + let plain_tuple = LuaType::Tuple( + LuaTupleType::new( + vec![LuaType::Number, LuaType::String], + LuaTupleStatus::DocResolve, + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_tuple, + &empty_mapper + ), + plain_tuple + ); + + let plain_object = LuaType::Object( + LuaObjectType::new_with_fields( + HashMap::from([ + (LuaMemberKey::Name(SmolStr::new("name")), LuaType::String), + (LuaMemberKey::Name(SmolStr::new("count")), LuaType::Number), + ]), + Vec::new(), + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_object, + &empty_mapper + ), + plain_object + ); + + let plain_union = + LuaType::Union(LuaUnionType::from_vec(vec![LuaType::Number, LuaType::String]).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_union, + &empty_mapper + ), + plain_union + ); + + let plain_intersection = LuaType::Intersection( + LuaIntersectionType::new(vec![LuaType::Number, LuaType::String]).into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_intersection, + &empty_mapper + ), + plain_intersection + ); + + let plain_table_generic = + LuaType::TableGeneric(Arc::new(vec![LuaType::Number, LuaType::String])); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_table_generic, + &empty_mapper + ), + plain_table_generic + ); + + let plain_variadic = LuaType::Variadic(VariadicType::Base(LuaType::Number).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_variadic, + &empty_mapper + ), + plain_variadic + ); + + let plain_doc_function = LuaType::DocFunction( + LuaFunctionType::new( + AsyncState::None, + false, + false, + vec![("value".to_string(), Some(LuaType::Number))], + LuaType::String, + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_doc_function, + &empty_mapper + ), + plain_doc_function + ); + + let plain_type_guard = LuaType::TypeGuard(Arc::new(LuaType::Number)); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_type_guard, + &empty_mapper + ), + plain_type_guard + ); + } + + #[test] + fn test_structural_instantiate_fast_path_instantiates_template_children() { + let db = DbIndex::new(); + let mapper = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let tpl = LuaType::TplRef(Arc::new(GenericTpl::new( + GenericTplId::Func(0), + SmolStr::new("T0").into(), + None, + None, + ))); + + let templated_array = LuaType::Array(LuaArrayType::from_base_type(tpl.clone()).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_array, + &mapper + ), + LuaType::Array(LuaArrayType::from_base_type(LuaType::String).into()) + ); + + let templated_tuple = LuaType::Tuple( + LuaTupleType::new( + vec![LuaType::Number, tpl.clone()], + LuaTupleStatus::DocResolve, + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_tuple, + &mapper + ), + LuaType::Tuple( + LuaTupleType::new( + vec![LuaType::Number, LuaType::String], + LuaTupleStatus::DocResolve, + ) + .into() + ) + ); + + let templated_object = LuaType::Object( + LuaObjectType::new_with_fields( + HashMap::from([ + (LuaMemberKey::Name(SmolStr::new("name")), tpl.clone()), + (LuaMemberKey::Name(SmolStr::new("count")), LuaType::Number), + ]), + Vec::new(), + ) + .into(), + ); + let expected_object = LuaType::Object( + LuaObjectType::new_with_fields( + HashMap::from([ + (LuaMemberKey::Name(SmolStr::new("name")), LuaType::String), + (LuaMemberKey::Name(SmolStr::new("count")), LuaType::Number), + ]), + Vec::new(), + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_object, + &mapper + ), + expected_object + ); + + let templated_union = + LuaType::Union(LuaUnionType::from_vec(vec![LuaType::Number, tpl.clone()]).into()); + let expected_union = + LuaType::Union(LuaUnionType::from_vec(vec![LuaType::Number, LuaType::String]).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_union, + &mapper + ), + expected_union + ); + + let templated_intersection = LuaType::Intersection( + LuaIntersectionType::new(vec![LuaType::Number, tpl.clone()]).into(), + ); + let expected_intersection = LuaType::Intersection( + LuaIntersectionType::new(vec![LuaType::Number, LuaType::String]).into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_intersection, + &mapper + ), + expected_intersection + ); + + let templated_table_generic = LuaType::TableGeneric(Arc::new(vec![tpl.clone()])); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_table_generic, + &mapper + ), + LuaType::TableGeneric(Arc::new(vec![LuaType::String])) + ); + + let templated_variadic = LuaType::Variadic(VariadicType::Base(tpl.clone()).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_variadic, + &mapper + ), + LuaType::Variadic(VariadicType::Base(LuaType::String).into()) + ); + + let templated_doc_function = LuaType::DocFunction( + LuaFunctionType::new( + AsyncState::None, + false, + false, + vec![("value".to_string(), Some(tpl.clone()))], + tpl.clone(), + ) + .into(), + ); + let expected_doc_function = LuaType::DocFunction( + LuaFunctionType::new( + AsyncState::None, + false, + false, + vec![("value".to_string(), Some(LuaType::String))], + LuaType::String, + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_doc_function, + &mapper + ), + expected_doc_function + ); + + let templated_type_guard = LuaType::TypeGuard(Arc::new(tpl.clone())); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_type_guard, + &mapper + ), + LuaType::TypeGuard(Arc::new(LuaType::String)) + ); + } + #[test] fn test_123() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_mapper.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_mapper.rs new file mode 100644 index 000000000..cab2afd50 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_mapper.rs @@ -0,0 +1,364 @@ +use std::{cell::RefCell, rc::Rc}; + +use hashbrown::HashMap; + +use crate::{DbIndex, GenericTplId, LuaType, LuaTypeDeclId, VariadicType}; + +#[derive(Debug, Clone)] +pub enum TypeMapper { + Simple { + source: GenericTplId, + target: TypeMapperValue, + }, + Array { + mappings: Rc>, + }, + InferenceFallback { + indices: Rc>, + targets: Rc>>, + }, + Merged { + layers: Rc<[TypeMapper]>, + }, +} + +impl TypeMapper { + pub fn empty() -> Self { + Self::from_values(Vec::new(), Vec::new()) + } + + pub fn from_values(sources: Vec, targets: Vec) -> Self { + if sources.len() == 1 { + return TypeMapper::Simple { + source: sources[0], + target: targets + .into_iter() + .next() + .unwrap_or(TypeMapperValue::Type(LuaType::Any)), + }; + } + + let mut mappings = HashMap::with_capacity(sources.len().min(targets.len())); + for (source, target) in sources.into_iter().zip(targets) { + mappings.entry(source).or_insert(target); + } + TypeMapper::Array { + mappings: Rc::new(mappings), + } + } + + pub fn from_type_array(type_array: Vec) -> Self { + let sources = (0..type_array.len()) + .map(|idx| GenericTplId::Type(idx as u32)) + .collect::>(); + let targets = type_array + .into_iter() + .map(TypeMapperValue::type_value) + .collect(); + Self::from_values(sources, targets) + } + + pub fn from_uninferred(sources: Vec) -> Self { + let targets = sources + .iter() + .map(|_| TypeMapperValue::None) + .collect::>(); + Self::from_values(sources, targets) + } + + pub fn from_alias( + db: &DbIndex, + type_array: Vec, + alias_type_id: &LuaTypeDeclId, + ) -> Self { + let params = db.get_type_index().get_generic_params(alias_type_id); + let sources = type_array + .iter() + .enumerate() + .map(|(i, _)| { + params + .and_then(|params| params.get(i)) + .and_then(|param| param.tpl_id) + .unwrap_or(GenericTplId::Type(i as u32)) + }) + .collect::>(); + let targets = type_array + .into_iter() + .map(TypeMapperValue::type_value) + .collect(); + Self::from_values(sources, targets) + } + + fn collect_layers(mapper: TypeMapper, layers: &mut Vec) { + match mapper { + TypeMapper::Merged { + layers: nested_layers, + } => { + for layer in nested_layers.iter().cloned() { + Self::collect_layers(layer, layers); + } + } + other => layers.push(other), + } + } + + pub fn from_inference_fallback( + indices: Rc>, + targets: Rc>>, + ) -> Self { + TypeMapper::InferenceFallback { indices, targets } + } + + pub fn merge(mapper1: Option, mapper2: TypeMapper) -> Self { + let mut layers = Vec::new(); + if let Some(mapper1) = mapper1 { + Self::collect_layers(mapper1, &mut layers); + } + Self::collect_layers(mapper2, &mut layers); + match layers.len() { + 0 => Self::empty(), + 1 => layers.remove(0), + _ => TypeMapper::Merged { + layers: Rc::from(layers.into_boxed_slice()), + }, + } + } + + pub fn prepend(source: GenericTplId, target: LuaType, mapper: Option) -> Self { + let unary = TypeMapper::Simple { + source, + target: TypeMapperValue::type_value(target), + }; + match mapper { + Some(mapper) => Self::merge(Some(unary), mapper), + None => unary, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TypeMapperValue { + None, + Type(LuaType), + Params(Vec<(String, Option)>), + MultiTypes(Vec), + MultiBase(LuaType), +} + +impl TypeMapperValue { + pub fn type_value(ty: LuaType) -> Self { + TypeMapperValue::Type(into_ref_type(ty)) + } + + pub fn params_value(params: Vec<(String, Option)>) -> Self { + TypeMapperValue::Params( + params + .into_iter() + .map(|(name, ty)| (name, ty.map(into_ref_type))) + .collect(), + ) + } + + pub fn raw_type(&self) -> Option { + match self { + TypeMapperValue::Type(ty) => Some(ty.clone()), + TypeMapperValue::Params(params) => params + .first() + .and_then(|(_, ty)| ty.clone()) + .or(Some(LuaType::Unknown)), + TypeMapperValue::MultiTypes(types) => { + Some(LuaType::Variadic(VariadicType::Multi(types.clone()).into())) + } + TypeMapperValue::MultiBase(base) => Some(base.clone()), + TypeMapperValue::None => None, + } + } + + fn direct_tpl_id(&self) -> Option { + match self { + TypeMapperValue::Type(LuaType::TplRef(tpl)) + | TypeMapperValue::Type(LuaType::ConstTplRef(tpl)) => Some(tpl.get_tpl_id()), + _ => None, + } + } +} + +pub(in crate::semantic::generic) fn get_mapped_value( + tpl_id: GenericTplId, + mapper: &TypeMapper, +) -> Option { + match mapper { + TypeMapper::Simple { source, target } => { + if *source == tpl_id { + Some(target.clone()) + } else { + None + } + } + TypeMapper::Array { mappings } => mappings.get(&tpl_id).cloned(), + TypeMapper::InferenceFallback { indices, targets } => { + let index = *indices.get(&tpl_id)?; + targets.borrow().get(index).cloned() + } + TypeMapper::Merged { layers } => { + let mut current_tpl_id = tpl_id; + let mut current_value: Option = None; + let mut has_direct_value = false; + + for layer in layers.iter() { + if let Some(value) = get_mapped_value(current_tpl_id, layer) { + if let Some(mapped_tpl_id) = value.direct_tpl_id() { + current_tpl_id = mapped_tpl_id; + current_value = Some(value); + has_direct_value = true; + continue; + } + + if matches!(value, TypeMapperValue::None) { + if has_direct_value { + continue; + } + return Some(TypeMapperValue::None); + } + + return Some(value); + } + } + + current_value + } + } +} + +fn into_ref_type(ty: LuaType) -> LuaType { + match ty { + LuaType::Def(type_decl_id) => LuaType::Ref(type_decl_id), + _ => ty, + } +} + +#[cfg(test)] +mod tests { + use std::{cell::RefCell, rc::Rc, sync::Arc}; + + use smol_str::SmolStr; + + use super::*; + use crate::{GenericTpl, GenericTplId, LuaArrayType}; + + fn tpl_ref(idx: u32) -> LuaType { + LuaType::TplRef(Arc::new(GenericTpl::new( + GenericTplId::Func(idx), + SmolStr::new(format!("T{}", idx)).into(), + None, + None, + ))) + } + + // 合并后的 mapper 只读直接结果, 不在查询阶段深度展开结构里的模板. + #[test] + fn merged_does_not_deep_instantiate_mapped_result() { + let first = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(LuaType::Array( + LuaArrayType::from_base_type(tpl_ref(1)).into(), + ))], + ); + let second = TypeMapper::from_values( + vec![GenericTplId::Func(1)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let mapper = TypeMapper::merge(Some(first), second); + + let mapped = + get_mapped_value(GenericTplId::Func(0), &mapper).and_then(|value| value.raw_type()); + + assert_eq!( + mapped, + Some(LuaType::Array( + LuaArrayType::from_base_type(tpl_ref(1)).into() + )) + ); + } + + // 直接的 TplRef 链要继续追到后层 mapper. + #[test] + fn merged_resolves_direct_tpl_ref_through_later_mapper() { + let first = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(tpl_ref(1))], + ); + let second = TypeMapper::from_values( + vec![GenericTplId::Func(1)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let mapper = TypeMapper::merge(Some(first), second); + + let mapped = + get_mapped_value(GenericTplId::Func(0), &mapper).and_then(|value| value.raw_type()); + + assert_eq!(mapped, Some(LuaType::String)); + } + + // 未推断的槽位要保留为 None, 不能和“没有映射”混掉. + #[test] + fn inference_fallback_keeps_unresolved_slots_as_none() { + let mut indices = HashMap::new(); + indices.insert(GenericTplId::Func(0), 0); + indices.insert(GenericTplId::Func(1), 1); + let targets = Rc::new(RefCell::new(vec![ + TypeMapperValue::None, + TypeMapperValue::None, + ])); + let mapper = TypeMapper::from_inference_fallback(Rc::new(indices), targets); + + assert_eq!( + get_mapped_value(GenericTplId::Func(1), &mapper), + Some(TypeMapperValue::None) + ); + } + + // 后层显式 None 不能抹掉前层已经建立的 TplRef 链. + #[test] + fn merged_keeps_direct_tpl_ref_when_later_mapper_is_explicit_none() { + let first = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(tpl_ref(1))], + ); + let second = + TypeMapper::from_values(vec![GenericTplId::Func(1)], vec![TypeMapperValue::None]); + let mapper = TypeMapper::merge(Some(first), second); + + let mapped = + get_mapped_value(GenericTplId::Func(0), &mapper).and_then(|value| value.raw_type()); + + assert_eq!(mapped, Some(tpl_ref(1))); + } + + // 长链别名要能一路追到最终具体类型. + #[test] + fn long_mapper_chain_resolves_transitively() { + let mut mapper = None; + for idx in 0..64 { + let source = GenericTplId::Func(idx); + let target = if idx == 63 { + LuaType::String + } else { + tpl_ref(idx + 1) + }; + let layer = + TypeMapper::from_values(vec![source], vec![TypeMapperValue::type_value(target)]); + mapper = Some(TypeMapper::merge(mapper, layer)); + } + + let mapper = mapper.expect("mapper"); + assert_eq!( + get_mapped_value(GenericTplId::Func(0), &mapper).and_then(|value| value.raw_type()), + Some(LuaType::String) + ); + assert_eq!( + get_mapped_value(GenericTplId::Func(31), &mapper).and_then(|value| value.raw_type()), + Some(LuaType::String) + ); + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs deleted file mode 100644 index b212f8fc7..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ /dev/null @@ -1,335 +0,0 @@ -use hashbrown::{HashMap, HashSet}; - -use crate::{DbIndex, GenericTplId, LuaType, LuaTypeDeclId}; -use std::sync::Arc; - -const MAX_INSTANTIATION_DEPTH: usize = 128; -const MAX_ALIAS_STACK: usize = 32; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(super) enum UninferredTplPolicy { - /// 未推断模板按 `default -> constraint -> unknown` 推断成实际类型. - Fallback, - /// 没有默认值的未推断模板仍保留为 `TplRef`, 让后续调用点继续参与参数推导. - PreserveTplRef, -} - -pub(in crate::semantic::generic) enum TplBinding { - FinalizedType(LuaType), - ReplaceConstType(LuaType), - ConditionalInferType(LuaType), - VariadicParams(Vec<(String, Option)>), - InferredMultiTypes(Vec), - VariadicBase(LuaType), -} - -#[derive(Debug)] -pub struct GenericInstantiateContext<'a> { - pub db: &'a DbIndex, - pub substitutor: &'a TypeSubstitutor, - alias_stack: Arc<[LuaTypeDeclId]>, -} - -#[derive(Debug, Clone, Copy)] -pub(super) struct GenericInstantiateFrame { - policy: UninferredTplPolicy, - depth: usize, -} - -impl<'a> GenericInstantiateContext<'a> { - pub fn new(db: &'a DbIndex, substitutor: &'a TypeSubstitutor) -> Self { - Self { - db, - substitutor, - alias_stack: Arc::from([]), - } - } - - pub(super) fn root_frame(&self) -> GenericInstantiateFrame { - GenericInstantiateFrame { - policy: UninferredTplPolicy::Fallback, - depth: 0, - } - } - - pub(super) fn with_substitutor<'b>( - &'b self, - substitutor: &'b TypeSubstitutor, - ) -> GenericInstantiateContext<'b> { - GenericInstantiateContext { - db: self.db, - substitutor, - alias_stack: self.alias_stack.clone(), - } - } - - pub(super) fn enter_alias( - &self, - alias_type_id: &LuaTypeDeclId, - ) -> Option> { - if self.alias_stack.len() >= MAX_ALIAS_STACK - || self.alias_stack.iter().any(|id| id == alias_type_id) - { - return None; - } - - let mut alias_stack = Vec::with_capacity(self.alias_stack.len() + 1); - alias_stack.extend(self.alias_stack.iter().cloned()); - alias_stack.push(alias_type_id.clone()); - Some(GenericInstantiateContext { - db: self.db, - substitutor: self.substitutor, - alias_stack: Arc::from(alias_stack), - }) - } -} - -impl GenericInstantiateFrame { - pub(super) fn with_policy(self, policy: UninferredTplPolicy) -> Self { - Self { policy, ..self } - } - - pub fn should_preserve_tpl_ref(&self) -> bool { - self.policy == UninferredTplPolicy::PreserveTplRef - } - - pub(super) fn enter(self) -> Option { - if self.depth >= MAX_INSTANTIATION_DEPTH { - return None; - } - - Some(Self { - depth: self.depth + 1, - ..self - }) - } -} - -#[derive(Debug, Clone)] -pub struct TypeSubstitutor { - tpl_replace_map: HashMap, - alias_type_id: Option, - self_type: Option, -} - -impl Default for TypeSubstitutor { - fn default() -> Self { - Self::new() - } -} - -impl TypeSubstitutor { - pub fn new() -> Self { - Self { - tpl_replace_map: HashMap::new(), - alias_type_id: None, - self_type: None, - } - } - - pub fn from_type_array(type_array: Vec) -> Self { - let mut tpl_replace_map = HashMap::new(); - for (i, ty) in type_array.into_iter().enumerate() { - tpl_replace_map.insert( - GenericTplId::Type(i as u32), - SubstitutorValue::Type { - value: SubstitutorTypeValue::new(ty), - }, - ); - } - Self { - tpl_replace_map, - alias_type_id: None, - self_type: None, - } - } - - pub fn from_alias( - db: &DbIndex, - type_array: Vec, - alias_type_id: LuaTypeDeclId, - ) -> Self { - let params = db.get_type_index().get_generic_params(&alias_type_id); - - let mut tpl_replace_map = HashMap::new(); - for (i, ty) in type_array.into_iter().enumerate() { - let tpl_id = params - .and_then(|params| params.get(i)) - .and_then(|param| param.tpl_id) - .unwrap_or(GenericTplId::Type(i as u32)); - tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type { - value: SubstitutorTypeValue::new(ty), - }, - ); - } - - Self { - tpl_replace_map, - alias_type_id: Some(alias_type_id), - self_type: None, - } - } - - pub fn prepare_inference_slots(&mut self, tpl_ids: HashSet) { - for tpl_id in tpl_ids { - // conditional infer id 只属于条件类型内部匹配, 不参与普通调用/类型泛型推导. - if tpl_id.is_conditional_infer() { - continue; - } - - self.tpl_replace_map - .entry(tpl_id) - .or_insert(SubstitutorValue::None); - } - } - - pub fn has_unresolved_inference_slots(&self) -> bool { - self.tpl_replace_map - .values() - .any(|value| matches!(value, SubstitutorValue::None)) - } - - pub fn bind_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType) { - self.bind(tpl_id, TplBinding::FinalizedType(replace_type)); - } - - pub(in crate::semantic::generic) fn bind(&mut self, tpl_id: GenericTplId, binding: TplBinding) { - match binding { - TplBinding::ConditionalInferType(replace_type) => { - // 只有 conditional true 分支提交 infer 结果时允许写入 scoped conditional infer id. - if !tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type { - value: SubstitutorTypeValue::new(replace_type), - }, - ); - } - TplBinding::ReplaceConstType(replace_type) => { - if tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type { - value: SubstitutorTypeValue::new(replace_type), - }, - ); - } - binding => { - // 普通替换入口不能写入 conditional infer, 避免条件类型局部绑定泄露到外层. - if tpl_id.is_conditional_infer() || !self.can_bind(tpl_id) { - return; - } - - let value = match binding { - TplBinding::FinalizedType(replace_type) => SubstitutorValue::Type { - value: SubstitutorTypeValue::new(replace_type), - }, - TplBinding::VariadicParams(params) => { - let params = params - .into_iter() - .map(|(name, ty)| (name, ty.map(into_ref_type))) - .collect(); - SubstitutorValue::Params(params) - } - TplBinding::InferredMultiTypes(types) => SubstitutorValue::MultiTypes { - values: types.into_iter().map(SubstitutorTypeValue::new).collect(), - }, - TplBinding::VariadicBase(type_base) => SubstitutorValue::MultiBase(type_base), - TplBinding::ReplaceConstType(_) | TplBinding::ConditionalInferType(_) => { - unreachable!("handled before regular binding") - } - }; - - self.tpl_replace_map.insert(tpl_id, value); - } - } - } - - fn can_bind(&self, tpl_id: GenericTplId) -> bool { - if let Some(value) = self.tpl_replace_map.get(&tpl_id) { - return value.is_none(); - } - - true - } - - pub(super) fn get(&self, tpl_id: GenericTplId) -> Option<&SubstitutorValue> { - self.tpl_replace_map.get(&tpl_id) - } - - pub fn get_raw_type(&self, tpl_id: GenericTplId) -> Option<&LuaType> { - match self.tpl_replace_map.get(&tpl_id) { - Some(SubstitutorValue::Type { value, .. }) => Some(value.raw()), - _ => None, - } - } - - pub fn check_recursion(&self, type_id: &LuaTypeDeclId) -> bool { - if let Some(alias_type_id) = &self.alias_type_id - && alias_type_id == type_id - { - return true; - } - - false - } - - pub fn add_self_type(&mut self, self_type: LuaType) { - self.self_type = Some(self_type); - } - - pub fn get_self_type(&self) -> Option<&LuaType> { - self.self_type.as_ref() - } -} - -#[derive(Debug, Clone)] -pub struct SubstitutorTypeValue { - raw: LuaType, -} - -impl SubstitutorTypeValue { - fn new(raw: LuaType) -> Self { - Self { - raw: into_ref_type(raw), - } - } - - pub fn raw(&self) -> &LuaType { - &self.raw - } - - pub(super) fn resolved(&self) -> &LuaType { - &self.raw - } -} - -#[derive(Debug, Clone)] -pub(super) enum SubstitutorValue { - None, - Type { value: SubstitutorTypeValue }, - Params(Vec<(String, Option)>), - MultiTypes { values: Vec }, - MultiBase(LuaType), -} - -impl SubstitutorValue { - pub fn is_none(&self) -> bool { - matches!(self, SubstitutorValue::None) - } -} - -fn into_ref_type(ty: LuaType) -> LuaType { - match ty { - LuaType::Def(type_decl_id) => LuaType::Ref(type_decl_id), - _ => ty, - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index 3c01665f1..b24e659cc 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -17,11 +17,14 @@ use crate::{ use crate::{ InferGuardRef, semantic::{ - generic::{TypeSubstitutor, get_tpl_ref_extend_type}, + generic::{TypeMapper, get_tpl_ref_extend_type}, infer::narrow::get_type_at_call_expr_inline_cast, }, }; -use crate::{build_self_type, infer_call_func_generic, infer_self_type, semantic::infer_expr}; +use crate::{ + build_self_type, infer_call_func_generic, infer_self_type, instantiate_type_generic_with_self, + semantic::infer_expr, +}; use infer_require::infer_require_call; use infer_setmetatable::infer_setmetatable_call; @@ -266,11 +269,10 @@ fn filter_callable_overloads_by_call_args( } let has_tpls = !callable_tpls.is_empty(); - let mut substitutor = TypeSubstitutor::new(); - substitutor.prepare_inference_slots(callable_tpls); + let mapper = TypeMapper::from_uninferred(callable_tpls.into_iter().collect()); let match_func = if has_tpls { let func_ty = LuaType::DocFunction(func.clone()); - match instantiate_type_generic(db, &func_ty, &substitutor) { + match instantiate_type_generic(db, &func_ty, &mapper) { LuaType::DocFunction(doc_func) => doc_func, _ => func.clone(), } @@ -362,12 +364,11 @@ fn infer_type_doc_function( let result = infer_call_func_generic(db, cache, &f, call_expr.clone())?; overloads.push(Arc::new(result)); } else if f.contain_self() { - let mut substitutor = TypeSubstitutor::new(); + let mapper = TypeMapper::empty(); let self_type = build_self_type(db, call_expr_type); - substitutor.add_self_type(self_type); let func_ty = LuaType::DocFunction(f.clone()); if let LuaType::DocFunction(f) = - instantiate_type_generic(db, &func_ty, &substitutor) + instantiate_type_generic_with_self(db, &func_ty, &mapper, Some(&self_type)) { overloads.push(f); } @@ -404,7 +405,7 @@ fn infer_generic_type_doc_function( let type_id = generic.get_base_type_id(); infer_guard.check(&type_id)?; let generic_params = generic.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); + let mapper = TypeMapper::from_type_array(generic_params.clone()); let type_decl = db .get_type_index() @@ -412,7 +413,7 @@ fn infer_generic_type_doc_function( .ok_or_else(|| InferFailReason::UnResolveTypeDecl(type_id.clone()))?; if type_decl.is_alias() { let origin_type = type_decl - .get_alias_origin(db, Some(&substitutor)) + .get_alias_origin(db, Some(&mapper)) .ok_or(InferFailReason::None)?; return infer_call_expr_func( db, @@ -438,7 +439,7 @@ fn infer_generic_type_doc_function( let func = operator.get_operator_func(db); match func { LuaType::DocFunction(_) => { - let new_f = instantiate_type_generic(db, &func, &substitutor); + let new_f = instantiate_type_generic(db, &func, &mapper); if let LuaType::DocFunction(f) = new_f { overloads.push(f.clone()); } @@ -453,7 +454,7 @@ fn infer_generic_type_doc_function( } let typ = LuaType::DocFunction(signature.to_call_operator_func_type()); - let new_f = instantiate_type_generic(db, &typ, &substitutor); + let new_f = instantiate_type_generic(db, &typ, &mapper); if let LuaType::DocFunction(f) = new_f { overloads.push(f.clone()); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs index 8c11e5f24..5e49a7bee 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs @@ -18,7 +18,7 @@ use crate::{ enum_variable_is_param, get_keyof_members, get_tpl_ref_extend_type, semantic::{ InferGuard, - generic::{TypeSubstitutor, instantiate_type_generic}, + generic::{TypeMapper, instantiate_type_generic}, infer::{ VarRefId, infer_index::infer_array::{ @@ -663,7 +663,7 @@ fn infer_generic_members_from_super_generics( db: &DbIndex, cache: &mut LuaInferCache, type_decl_id: &LuaTypeDeclId, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> Option { @@ -677,7 +677,7 @@ fn infer_generic_members_from_super_generics( let type_decl_id = type_decl.get_id(); if let Some(super_types) = type_index.get_super_types(&type_decl_id) { super_types.iter().find_map(|super_type| { - let super_type = instantiate_type_generic(db, super_type, substitutor); + let super_type = instantiate_type_generic(db, super_type, mapper); infer_member_by_lookup(db, cache, &super_type, lookup, &infer_guard.fork()).ok() }) } else { @@ -700,16 +700,16 @@ fn infer_generic_member( return Err(InferFailReason::None); }; let generic_params = generic_type.get_params(); - let substitutor = if type_decl.is_alias() { - TypeSubstitutor::from_alias(db, generic_params.clone(), base_type_decl_id.clone()) + let mapper = if type_decl.is_alias() { + TypeMapper::from_alias(db, generic_params.clone(), base_type_decl_id) } else { - TypeSubstitutor::from_type_array(generic_params.clone()) + TypeMapper::from_type_array(generic_params.clone()) }; if type_decl.is_alias() && let Some(origin_type) = type_decl .get_alias_ref() - .map(|origin| instantiate_type_generic(db, origin, &substitutor)) + .map(|origin| instantiate_type_generic(db, origin, &mapper)) { return infer_member_by_lookup(db, cache, &origin_type, lookup, &infer_guard.fork()); } @@ -718,7 +718,7 @@ fn infer_generic_member( db, cache, base_type_decl_id, - &substitutor, + &mapper, lookup, infer_guard, ); @@ -727,7 +727,7 @@ fn infer_generic_member( } let member_type = infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard)?; - return Ok(instantiate_type_generic(db, &member_type, &substitutor)); + return Ok(instantiate_type_generic(db, &member_type, &mapper)); } infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard) @@ -999,15 +999,15 @@ fn infer_member_by_index_generic( let type_decl = type_index .get_type_decl(&type_decl_id) .ok_or(InferFailReason::None)?; - let substitutor = if type_decl.is_alias() { - TypeSubstitutor::from_alias(db, generic_params.clone(), type_decl_id.clone()) + let mapper = if type_decl.is_alias() { + TypeMapper::from_alias(db, generic_params.clone(), &type_decl_id) } else { - TypeSubstitutor::from_type_array(generic_params.clone()) + TypeMapper::from_type_array(generic_params.clone()) }; if type_decl.is_alias() { if let Some(origin_type) = type_decl .get_alias_ref() - .map(|origin| instantiate_type_generic(db, origin, &substitutor)) + .map(|origin| instantiate_type_generic(db, origin, &mapper)) { return infer_member_by_operator_key_type( db, @@ -1029,9 +1029,9 @@ fn infer_member_by_index_generic( .get_operator(index_operator_id) .ok_or(InferFailReason::None)?; let operand = index_operator.get_operand(db); - let instianted_operand = instantiate_type_generic(db, &operand, &substitutor); + let instianted_operand = instantiate_type_generic(db, &operand, &mapper); let return_type = - instantiate_type_generic(db, &index_operator.get_result(db)?, &substitutor); + instantiate_type_generic(db, &index_operator.get_result(db)?, &mapper); let result = infer_index_metamethod_by_key_type(db, key_type, &instianted_operand, &return_type); @@ -1054,7 +1054,7 @@ fn infer_member_by_index_generic( let result = infer_member_by_operator_key_type( db, cache, - &instantiate_type_generic(db, &super_type, &substitutor), + &instantiate_type_generic(db, &super_type, &mapper), key_type, &infer_guard.fork(), ); diff --git a/crates/emmylua_code_analysis/src/semantic/member/find_index.rs b/crates/emmylua_code_analysis/src/semantic/member/find_index.rs index 20f8d4422..091031564 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/find_index.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/find_index.rs @@ -3,11 +3,8 @@ use hashbrown::{HashMap, HashSet}; use crate::{ DbIndex, InFiled, InferGuardRef, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, - LuaType, LuaTypeDeclId, LuaUnionType, TypeOps, - semantic::{ - InferGuard, - generic::{TypeSubstitutor, instantiate_type_generic}, - }, + LuaType, LuaTypeDeclId, LuaUnionType, TypeMapper, TypeOps, + semantic::{InferGuard, generic::instantiate_type_generic}, }; use super::{FindMembersResult, LuaMemberInfo, intersect_member_types}; @@ -310,13 +307,13 @@ fn find_index_generic( }; let generic_params = generic.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); + let mapper = TypeMapper::from_type_array(generic_params.clone()); let type_index = db.get_type_index(); let type_decl = type_index.get_type_decl(&type_decl_id)?; if type_decl.is_alias() { - if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { - let instantiated_type = instantiate_type_generic(db, &origin_type, &substitutor); + if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&mapper)) { + let instantiated_type = instantiate_type_generic(db, &origin_type, &mapper); return find_index_operations_guard(db, &instantiated_type, infer_guard); } return None; @@ -332,11 +329,11 @@ fn find_index_generic( for index_operator_id in index_operator_ids { if let Some(index_operator) = operator_index.get_operator(index_operator_id) { let operand = index_operator.get_operand(db); - let instantiated_operand = instantiate_type_generic(db, &operand, &substitutor); + let instantiated_operand = instantiate_type_generic(db, &operand, &mapper); if let Ok(return_type) = index_operator.get_result(db) { let instantiated_return_type = - instantiate_type_generic(db, &return_type, &substitutor); + instantiate_type_generic(db, &return_type, &mapper); members.push(LuaMemberInfo { property_owner_id: None, @@ -353,7 +350,7 @@ fn find_index_generic( // Find index operations in super types if let Some(supers) = type_index.get_super_types(&type_decl_id) { for super_type in supers { - let instantiated_super = instantiate_type_generic(db, &super_type, &substitutor); + let instantiated_super = instantiate_type_generic(db, &super_type, &mapper); if let Some(super_members) = find_index_operations_guard(db, &instantiated_super, infer_guard) { diff --git a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs index 58284dc3d..a166687bf 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs @@ -8,7 +8,7 @@ use crate::{ LuaTypeDeclId, LuaUnionType, semantic::{ InferGuard, - generic::{TypeSubstitutor, instantiate_type_generic}, + generic::{TypeMapper, instantiate_type_generic}, }, }; @@ -84,7 +84,7 @@ pub fn find_members_with_key_in_scope( struct FindMembersContext { file_id: FileId, infer_guard: InferGuardRef, - substitutor: Option, + mapper: Option, } impl FindMembersContext { @@ -92,14 +92,14 @@ impl FindMembersContext { Self { file_id, infer_guard, - substitutor: None, + mapper: None, } } - fn with_substitutor(&self, substitutor: TypeSubstitutor) -> Self { + fn with_mapper(&self, mapper: TypeMapper) -> Self { Self { file_id: self.file_id, infer_guard: self.infer_guard.clone(), - substitutor: Some(substitutor), + mapper: Some(mapper), } } @@ -107,13 +107,13 @@ impl FindMembersContext { Self { file_id: self.file_id, infer_guard: self.infer_guard.fork(), - substitutor: self.substitutor.clone(), + mapper: self.mapper.clone(), } } fn instantiate_type(&self, db: &DbIndex, ty: &LuaType) -> LuaType { - if let Some(substitutor) = &self.substitutor { - instantiate_type_generic(db, ty, substitutor) + if let Some(mapper) = &self.mapper { + instantiate_type_generic(db, ty, mapper) } else { ty.clone() } @@ -496,23 +496,23 @@ fn find_generic_members( .map(|param| ctx.instantiate_type(db, param)) .collect(); let type_decl = db.get_type_index().get_type_decl(&base_ref_id)?; - let substitutor = if type_decl.is_alias() { - TypeSubstitutor::from_alias(db, instantiated_params, base_ref_id.clone()) + let mapper = if type_decl.is_alias() { + TypeMapper::from_alias(db, instantiated_params, &base_ref_id) } else { - TypeSubstitutor::from_type_array(instantiated_params) + TypeMapper::from_type_array(instantiated_params) }; - let ctx_with_substitutor = ctx.with_substitutor(substitutor.clone()); + let ctx_with_mapper = ctx.with_mapper(mapper.clone()); if let Some(origin) = type_decl .get_alias_ref() - .map(|origin| instantiate_type_generic(db, origin, &substitutor)) + .map(|origin| instantiate_type_generic(db, origin, &mapper)) { - return find_members_guard(db, &origin, &ctx_with_substitutor, filter); + return find_members_guard(db, &origin, &ctx_with_mapper, filter); } find_members_guard( db, &LuaType::Ref(base_ref_id.clone()), - &ctx_with_substitutor, + &ctx_with_mapper, filter, ) } diff --git a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs index 7edaffb2a..e322cd9e3 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs @@ -4,9 +4,8 @@ use smol_str::SmolStr; use crate::{ DbIndex, InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaMemberKey, - LuaMemberOwner, LuaObjectType, LuaTupleType, LuaType, LuaTypeDeclId, TypeOps, - check_type_compact, - semantic::generic::{TypeSubstitutor, instantiate_type_generic}, + LuaMemberOwner, LuaObjectType, LuaTupleType, LuaType, LuaTypeDeclId, TypeMapper, TypeOps, + check_type_compact, semantic::generic::instantiate_type_generic, }; use super::{RawGetMemberTypeResult, get_buildin_type_map_type_id}; @@ -218,17 +217,17 @@ fn infer_generic_raw_member_type( .get_type_index() .get_type_decl(&base_ref_id) .ok_or(InferFailReason::None)?; - let substitutor = if type_decl.is_alias() { - TypeSubstitutor::from_alias(db, generic_params.clone(), base_ref_id.clone()) + let mapper = if type_decl.is_alias() { + TypeMapper::from_alias(db, generic_params.clone(), &base_ref_id) } else { - TypeSubstitutor::from_type_array(generic_params.clone()) + TypeMapper::from_type_array(generic_params.clone()) }; - if let Some(origin) = type_decl.get_alias_origin(db, Some(&substitutor)) { + if let Some(origin) = type_decl.get_alias_origin(db, Some(&mapper)) { return infer_raw_member_type_guard(db, &origin, member_key, infer_guard); } let base_ref_type = LuaType::Ref(base_ref_id.clone()); let result = infer_raw_member_type_guard(db, &base_ref_type, member_key, infer_guard)?; - Ok(instantiate_type_generic(db, &result, &substitutor)) + Ok(instantiate_type_generic(db, &result, &mapper)) } diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs index fd5f568af..19d139395 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs @@ -4,7 +4,7 @@ use hashbrown::HashSet; use crate::db_index::{DbIndex, LuaFunctionType, LuaType, LuaTypeDeclId}; -use super::super::{generic::TypeSubstitutor, infer::InferFailReason}; +use super::super::{generic::TypeMapper, infer::InferFailReason}; pub(crate) fn collect_callable_overload_groups( db: &DbIndex, @@ -43,15 +43,13 @@ fn collect_callable_overload_groups_inner( if !visiting_aliases.insert(type_id.clone()) { return Ok(()); } - let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + let mapper = TypeMapper::from_type_array(generic.get_params().to_vec()); let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { visiting_aliases.remove(&type_id); return Ok(()); }; - let result = if let Some(origin_type) = - type_decl.get_alias_origin(db, Some(&substitutor)) - { + let result = if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&mapper)) { collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) } else { Ok(()) diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs index 763258b1f..4f7421789 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs @@ -13,8 +13,7 @@ use table_generic_check::check_table_generic_type_compact; use tuple_type_check::check_tuple_type_compact; use crate::{ - LuaType, LuaUnionType, TypeSubstitutor, - semantic::type_check::type_check_context::TypeCheckContext, + LuaType, LuaUnionType, TypeMapper, semantic::type_check::type_check_context::TypeCheckContext, }; use super::{ @@ -37,12 +36,9 @@ pub fn check_complex_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = TypeSubstitutor::from_alias( - context.db, - generic.get_params().clone(), - base_id.clone(), - ); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { + let mapper = + TypeMapper::from_alias(context.db, generic.get_params().clone(), &base_id); + if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&mapper)) { return check_general_type_compact( context, source, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs index e7af2fa1c..796677345 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs @@ -1,5 +1,5 @@ use crate::{ - TypeSubstitutor, + TypeMapper, db_index::{LuaFunctionType, LuaOperatorMetaMethod, LuaSignatureId, LuaType, LuaTypeDeclId}, semantic::type_check::type_check_context::TypeCheckContext, }; @@ -23,12 +23,9 @@ pub fn check_doc_func_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = TypeSubstitutor::from_alias( - context.db, - generic.get_params().clone(), - base_id.clone(), - ); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { + let mapper = + TypeMapper::from_alias(context.db, generic.get_params().clone(), &base_id); + if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&mapper)) { return check_general_type_compact( context, &LuaType::DocFunction(source_func.clone().into()), diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs index 85bec1bb7..e12e0c949 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs @@ -1,8 +1,8 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ - LuaGenericType, LuaMemberOwner, LuaType, LuaTypeCache, LuaTypeDeclId, RenderLevel, - TypeSubstitutor, complete_type_generic_args_in_type, humanize_type, instantiate_type_generic, + LuaGenericType, LuaMemberOwner, LuaType, LuaTypeCache, LuaTypeDeclId, RenderLevel, TypeMapper, + complete_type_generic_args_in_type, humanize_type, instantiate_type_generic, semantic::{member::find_members, type_check::type_check_context::TypeCheckContext}, }; @@ -24,13 +24,10 @@ pub fn check_generic_type_compact( .get_type_decl(&source_generic.get_base_type_id()) && decl.is_alias() { - let substitutor = TypeSubstitutor::from_alias( - context.db, - source_generic.get_params().clone(), - base_id.clone(), - ); + let mapper = + TypeMapper::from_alias(context.db, source_generic.get_params().clone(), &base_id); if let Some(alias_ref) = decl.get_alias_ref() { - let alias_origin = instantiate_type_generic(context.db, alias_ref, &substitutor); + let alias_origin = instantiate_type_generic(context.db, alias_ref, &mapper); return check_general_type_compact( context, &alias_origin, @@ -65,10 +62,9 @@ pub fn check_generic_type_compact( { for mut super_type in supers { if super_type.contain_tpl() { - let substitutor = - TypeSubstitutor::from_type_array(compact_generic.get_params().clone()); - super_type = - instantiate_type_generic(context.db, &super_type, &substitutor); + let mapper = + TypeMapper::from_type_array(compact_generic.get_params().clone()); + super_type = instantiate_type_generic(context.db, &super_type, &mapper); } let result = check_generic_type_compact( diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index 70865eec2..eb15bfbbe 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -1,7 +1,7 @@ use std::ops::Deref; use crate::{ - DbIndex, LuaType, LuaTypeDeclId, TypeSubstitutor, VariadicType, + DbIndex, LuaType, LuaTypeDeclId, TypeMapper, VariadicType, semantic::type_check::{ is_sub_type_of, type_check_context::{TypeCheckCheckLevel, TypeCheckContext}, @@ -310,14 +310,9 @@ pub fn check_simple_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = TypeSubstitutor::from_alias( - context.db, - generic.get_params().clone(), - base_id.clone(), - ); - if let Some(alias_origin) = - decl.get_alias_origin(context.db, Some(&substitutor)) - { + let mapper = + TypeMapper::from_alias(context.db, generic.get_params().clone(), &base_id); + if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&mapper)) { return check_general_type_compact( context, source, diff --git a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs index 0e9050f4b..a9c77a058 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs @@ -2,7 +2,7 @@ use emmylua_code_analysis::{ DbIndex, GenericTplId, InferGuard, InferGuardRef, LuaAliasCallKind, LuaAliasCallType, LuaDeclLocation, LuaFunctionType, LuaMember, LuaMemberKey, LuaMemberOwner, LuaMultiLineUnion, LuaSemanticDeclId, LuaStringTplType, LuaType, LuaTypeCache, LuaTypeDeclId, LuaUnionType, - RenderLevel, SemanticDeclLevel, TypeSubstitutor, build_call_constraint_context, get_real_type, + RenderLevel, SemanticDeclLevel, TypeMapper, build_call_constraint_context, get_real_type, instantiate_type_generic, normalize_constraint_type, }; use emmylua_parser::{ @@ -340,16 +340,16 @@ fn infer_call_arg_list( param_idx += 1; } } - let constraint_substitutor = build_call_constraint_context(&builder.semantic_model, &call_expr) - .map(|ctx| ctx.substitutor); - let substitutor = constraint_substitutor.as_ref(); + let constraint_mapper = + build_call_constraint_context(&builder.semantic_model, &call_expr).map(|ctx| ctx.mapper); + let mapper = constraint_mapper.as_ref(); let typ = call_expr_func .get_params() .get(param_idx)? .1 .clone() .unwrap_or(LuaType::Unknown); - let typ = resolve_param_type(builder, typ, substitutor); + let typ = resolve_param_type(builder, typ, mapper); let mut types = Vec::new(); types.push(typ); push_function_overloads_param( @@ -357,7 +357,7 @@ fn infer_call_arg_list( &call_expr, call_expr_func.get_params(), param_idx, - substitutor, + mapper, &mut types, ); Some(types.into_iter().unique().collect()) // 需要去重 @@ -366,22 +366,22 @@ fn infer_call_arg_list( fn resolve_param_type( builder: &CompletionBuilder, mut typ: LuaType, - substitutor: Option<&TypeSubstitutor>, + mapper: Option<&TypeMapper>, ) -> LuaType { let db = builder.semantic_model.get_db(); - if let Some(substitutor) = substitutor { - typ = apply_substitutor_to_type(db, typ, substitutor); + if let Some(mapper) = mapper { + typ = apply_mapper_to_type(db, typ, mapper); } normalize_constraint_type(db, typ) } -fn apply_substitutor_to_type(db: &DbIndex, typ: LuaType, substitutor: &TypeSubstitutor) -> LuaType { +fn apply_mapper_to_type(db: &DbIndex, typ: LuaType, mapper: &TypeMapper) -> LuaType { if let LuaType::Call(alias_call) = &typ { if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf { let operands = alias_call .get_operands() .iter() - .map(|operand| instantiate_type_generic(db, operand, substitutor)) + .map(|operand| instantiate_type_generic(db, operand, mapper)) .collect::>(); return LuaType::Call(Arc::new(LuaAliasCallType::new( alias_call.get_call_kind(), @@ -389,16 +389,16 @@ fn apply_substitutor_to_type(db: &DbIndex, typ: LuaType, substitutor: &TypeSubst ))); } } - if let Some(alias_call) = rebuild_keyof_alias_call(db, &typ, substitutor) { + if let Some(alias_call) = rebuild_keyof_alias_call(db, &typ, mapper) { return alias_call; } - instantiate_type_generic(db, &typ, substitutor) + instantiate_type_generic(db, &typ, mapper) } fn rebuild_keyof_alias_call( db: &DbIndex, original_type: &LuaType, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, ) -> Option { let tpl = match original_type { LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl, @@ -415,7 +415,7 @@ fn rebuild_keyof_alias_call( let operands = alias_call .get_operands() .iter() - .map(|operand| instantiate_type_generic(db, operand, substitutor)) + .map(|operand| instantiate_type_generic(db, operand, mapper)) .collect::>(); Some(LuaType::Call(Arc::new(LuaAliasCallType::new( alias_call.get_call_kind(), @@ -428,7 +428,7 @@ fn push_function_overloads_param( call_expr: &LuaCallExpr, call_params: &[(String, Option)], param_idx: usize, - substitutor: Option<&TypeSubstitutor>, + mapper: Option<&TypeMapper>, types: &mut Vec, ) -> Option<()> { let member_index = builder.semantic_model.get_db().get_member_index(); @@ -508,7 +508,7 @@ fn push_function_overloads_param( // 添加匹配的参数类型 if let Some(param_type) = overload_params.get(param_idx).and_then(|p| p.1.clone()) { - let param_type = resolve_param_type(builder, param_type, substitutor); + let param_type = resolve_param_type(builder, param_type, mapper); types.push(param_type); } } diff --git a/crates/emmylua_ls/src/handlers/hover/function/mod.rs b/crates/emmylua_ls/src/handlers/hover/function/mod.rs index c8f3e4818..7e046bb05 100644 --- a/crates/emmylua_ls/src/handlers/hover/function/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/function/mod.rs @@ -2,8 +2,8 @@ use std::{collections::HashSet, sync::Arc, vec}; use emmylua_code_analysis::{ AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, - LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, - TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, infer_call_func_generic, + LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, TypeMapper, + VariadicType, humanize_type, infer_call_expr_func, infer_call_func_generic, instantiate_type_generic, try_extract_signature_id_from_field, }; @@ -188,8 +188,8 @@ fn build_function_define_hover( _ => None, }; - if let Some(substitutor) = &builder.substitutor { - if let Some(lua_func) = hover_instantiate_function_type(db, &typ, substitutor) { + if let Some(mapper) = &builder.mapper { + if let Some(lua_func) = hover_instantiate_function_type(db, &typ, mapper) { typ = LuaType::DocFunction(lua_func); } } @@ -696,14 +696,14 @@ fn function_member_is_field(db: &DbIndex, semantic_decls: &[(LuaSemanticDeclId, fn hover_instantiate_function_type( db: &DbIndex, typ: &LuaType, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, ) -> Option> { if !typ.contain_tpl() { return None; } match typ { LuaType::DocFunction(_) => { - if let LuaType::DocFunction(f) = instantiate_type_generic(db, typ, substitutor) { + if let LuaType::DocFunction(f) = instantiate_type_generic(db, typ, mapper) { Some(f) } else { None diff --git a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs index 61b81de1f..642b6c2e6 100644 --- a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs +++ b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ - GenericTplId, LuaCompilation, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, - RenderLevel, SemanticModel, TypeSubstitutor, + LuaCompilation, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, RenderLevel, + SemanticModel, TypeMapper, }; use emmylua_parser::{ LuaAstNode, LuaCallExpr, LuaExpr, LuaLocalName, LuaLocalStat, LuaSyntaxKind, LuaSyntaxToken, @@ -34,8 +34,8 @@ pub struct HoverBuilder<'a> { pub detail_render_level: RenderLevel, pub is_completion: bool, - // 默认的泛型替换器 - pub substitutor: Option, + // 默认的泛型 mapper + pub mapper: Option, } impl<'a> HoverBuilder<'a> { @@ -52,8 +52,8 @@ impl<'a> HoverBuilder<'a> { RenderLevel::Detailed }; - let substitutor = if let Some(token) = token.clone() { - infer_substitutor_base_type(semantic_model, token) + let mapper = if let Some(token) = token.clone() { + infer_mapper_base_type(semantic_model, token) } else { None }; @@ -70,7 +70,7 @@ impl<'a> HoverBuilder<'a> { type_expansion: None, tag_content: None, detail_render_level, - substitutor, + mapper, } } @@ -293,11 +293,11 @@ impl<'a> HoverBuilder<'a> { } } -// 推断基础泛型替换器 -fn infer_substitutor_base_type( +// 推断基础泛型 mapper +fn infer_mapper_base_type( semantic_model: &SemanticModel, trigger_token: LuaSyntaxToken, -) -> Option { +) -> Option { let parent = trigger_token.parent()?; match parent.kind().into() { LuaSyntaxKind::LocalName => { @@ -312,7 +312,7 @@ fn infer_substitutor_base_type( for (index, name) in local_name_list.iter().enumerate() { if target_local_name == *name { let value_expr = value_expr_list.get(index)?; - return substitutor_form_expr(semantic_model, value_expr); + return mapper_from_expr(semantic_model, value_expr); } } } @@ -325,20 +325,13 @@ fn infer_substitutor_base_type( None } -pub fn substitutor_form_expr( - semantic_model: &SemanticModel, - expr: &LuaExpr, -) -> Option { +pub fn mapper_from_expr(semantic_model: &SemanticModel, expr: &LuaExpr) -> Option { if let LuaExpr::IndexExpr(index_expr) = expr { let prefix_type = semantic_model .infer_expr(index_expr.get_prefix_expr()?) .ok()?; - let mut substitutor = TypeSubstitutor::new(); if let LuaType::Generic(generic) = prefix_type { - for (i, param) in generic.get_params().iter().enumerate() { - substitutor.bind_type(GenericTplId::Type(i as u32), param.clone()); - } - return Some(substitutor); + return Some(TypeMapper::from_type_array(generic.get_params().clone())); } else { return None; } diff --git a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs index 4e5cef58f..b2bcfd1bb 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs @@ -1,7 +1,7 @@ use emmylua_code_analysis::{ DbIndex, InFiled, LuaCompilation, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, LuaSignatureId, LuaType, - LuaTypeDeclId, RenderLevel, SemanticModel, TypeSubstitutor, + LuaTypeDeclId, RenderLevel, SemanticModel, TypeMapper, }; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken, LuaTokenKind}; use lsp_types::{ @@ -566,8 +566,8 @@ fn build_generic_signature_help( let generic_params = generic.get_params(); let type_decl_id = generic.get_base_type_id_ref(); if let Some(type_decl) = db.get_type_index().get_type_decl(type_decl_id) { - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); - if let Some(LuaType::DocFunction(f)) = type_decl.get_alias_origin(db, Some(&substitutor)) { + let mapper = TypeMapper::from_type_array(generic_params.clone()); + if let Some(LuaType::DocFunction(f)) = type_decl.get_alias_origin(db, Some(&mapper)) { let semantic_id = LuaSemanticDeclId::TypeDecl(type_decl_id.clone()); let description = db .get_property_index() From 9a4ec66ff4dc91f690f1011115792abfbeb19b32 Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Fri, 22 May 2026 22:52:07 +0800 Subject: [PATCH 08/10] fix generic --- .../src/compilation/analyzer/lua/stats.rs | 85 +++++++++--- .../test/for_range_var_infer_test.rs | 29 +++++ .../src/compilation/test/member_infer_test.rs | 28 ++++ .../semantic/generic/inference/infer_types.rs | 122 ++++++++++++------ .../src/semantic/generic/inference/mod.rs | 4 + .../src/semantic/generic/inference/tests.rs | 74 ++++++++++- .../instantiate_special_generic.rs | 29 +---- 7 files changed, 283 insertions(+), 88 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index 4cd8389a9..b055e22b6 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -1,8 +1,9 @@ use emmylua_parser::{ BinaryOperator, LuaAssignStat, LuaAstNode, LuaExpr, LuaFuncStat, LuaIndexExpr, LuaIndexKey, - LuaLocalFuncStat, LuaLocalStat, LuaNameExpr, LuaTableExpr, LuaTableField, LuaVarExpr, - PathTrait, + LuaLocalFuncStat, LuaLocalStat, LuaNameExpr, LuaSyntaxId, LuaTableExpr, LuaTableField, + LuaVarExpr, PathTrait, }; +use hashbrown::HashSet; use crate::{ InFiled, InferFailReason, LuaBuiltinAttributeKind, LuaLspOptimizationCode, LuaMemberKey, @@ -497,12 +498,16 @@ pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> return Some(()); } + let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); + if is_table_field_already_analyzed(analyzer, &field, &member_id) { + return Some(()); + } + if let Some(field_key) = field.get_field_key() { if let LuaIndexKey::Expr(_) = &field_key { // Decl analysis leaves `[expr] = value` fields unresolved. If the key // already resolves here, materialize the member now. let db = &mut *analyzer.db; - let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); if db.get_member_index().get_member(&member_id).is_none() { let cache = analyzer .context @@ -530,7 +535,6 @@ pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> } } - let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); if analyzer .db .get_type_index() @@ -567,6 +571,30 @@ pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> Some(()) } +fn is_table_field_already_analyzed( + analyzer: &LuaAnalyzer, + field: &LuaTableField, + member_id: &LuaMemberId, +) -> bool { + if analyzer + .db + .get_type_index() + .get_type_cache(&(*member_id).into()) + .is_none() + { + return false; + } + + match field.get_field_key() { + Some(LuaIndexKey::Expr(_)) => analyzer + .db + .get_member_index() + .get_member(member_id) + .is_some(), + _ => true, + } +} + fn special_assign_pattern( analyzer: &mut LuaAnalyzer, type_owner: LuaTypeOwner, @@ -632,58 +660,81 @@ fn get_delayed_definition_decl_id( } fn pre_analyze_call_arg_table_fields(analyzer: &mut LuaAnalyzer, expr: &LuaExpr) { - pre_analyze_nested_table_fields(analyzer, expr.clone()); + let mut analyzed_fields = HashSet::new(); + pre_analyze_nested_table_fields(analyzer, expr.clone(), 0, &mut analyzed_fields); } -fn pre_analyze_nested_table_fields(analyzer: &mut LuaAnalyzer, expr: LuaExpr) { +fn pre_analyze_nested_table_fields( + analyzer: &mut LuaAnalyzer, + expr: LuaExpr, + depth: usize, + analyzed_fields: &mut HashSet, +) { + if depth >= 250 { + return; + } + + let next_depth = depth + 1; match expr { LuaExpr::CallExpr(call_expr) => { if let Some(prefix_expr) = call_expr.get_prefix_expr() { - pre_analyze_nested_table_fields(analyzer, prefix_expr); + pre_analyze_nested_table_fields(analyzer, prefix_expr, next_depth, analyzed_fields); } if let Some(args_list) = call_expr.get_args_list() { for arg in args_list.get_args() { - pre_analyze_nested_table_fields(analyzer, arg); + pre_analyze_nested_table_fields(analyzer, arg, next_depth, analyzed_fields); } } } LuaExpr::TableExpr(table_expr) => { for field in table_expr.get_fields() { if let Some(LuaIndexKey::Expr(key_expr)) = field.get_field_key() { - pre_analyze_nested_table_fields(analyzer, key_expr); + pre_analyze_nested_table_fields( + analyzer, + key_expr, + next_depth, + analyzed_fields, + ); } if let Some(value_expr) = field.get_value_expr() { - pre_analyze_nested_table_fields(analyzer, value_expr); + pre_analyze_nested_table_fields( + analyzer, + value_expr, + next_depth, + analyzed_fields, + ); } - analyze_table_field(analyzer, field.clone()); + if analyzed_fields.insert(field.get_syntax_id()) { + analyze_table_field(analyzer, field.clone()); + } } } LuaExpr::BinaryExpr(binary_expr) => { if let Some((left, right)) = binary_expr.get_exprs() { - pre_analyze_nested_table_fields(analyzer, left); - pre_analyze_nested_table_fields(analyzer, right); + pre_analyze_nested_table_fields(analyzer, left, next_depth, analyzed_fields); + pre_analyze_nested_table_fields(analyzer, right, next_depth, analyzed_fields); } } LuaExpr::UnaryExpr(unary_expr) => { if let Some(inner_expr) = unary_expr.get_expr() { - pre_analyze_nested_table_fields(analyzer, inner_expr); + pre_analyze_nested_table_fields(analyzer, inner_expr, next_depth, analyzed_fields); } } LuaExpr::ParenExpr(paren_expr) => { if let Some(inner_expr) = paren_expr.get_expr() { - pre_analyze_nested_table_fields(analyzer, inner_expr); + pre_analyze_nested_table_fields(analyzer, inner_expr, next_depth, analyzed_fields); } } LuaExpr::IndexExpr(index_expr) => { if let Some(prefix_expr) = index_expr.get_prefix_expr() { - pre_analyze_nested_table_fields(analyzer, prefix_expr); + pre_analyze_nested_table_fields(analyzer, prefix_expr, next_depth, analyzed_fields); } if let Some(LuaIndexKey::Expr(key_expr)) = index_expr.get_index_key() { - pre_analyze_nested_table_fields(analyzer, key_expr); + pre_analyze_nested_table_fields(analyzer, key_expr, next_depth, analyzed_fields); } } LuaExpr::LiteralExpr(_) | LuaExpr::ClosureExpr(_) | LuaExpr::NameExpr(_) => {} diff --git a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs index 87928325c..e4291933e 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs @@ -243,6 +243,35 @@ mod test { assert_eq!(value_union.into_set(), expected_values); } + #[test] + fn test_pairs_metamethod_extracts_iterator_from_multi_return() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class PairSource + local PairSource + + ---@return fun(): string, number + ---@return table + ---@return nil + function PairSource:__pairs() + end + + ---@type PairSource + local source + + for k, v in pairs(source) do + key_out = k + value_out = v + end + "#, + ); + + assert_eq!(ws.expr_ty("key_out"), LuaType::String); + assert_eq!(ws.expr_ty("value_out"), LuaType::Number); + } + #[test] fn test_issue_291() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs index 00307054f..14a0f93c7 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs @@ -395,6 +395,34 @@ mod test { ); } + #[test] + fn test_call_arg_nested_table_expr_key_doc_const() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + local function id(value) + return value + end + + ---@type 'field' + local key = "field" + local t = id({ nested = { [key] = 1 } }).nested + value = t[key] + "#, + ); + + let value_ty = ws.expr_ty("value"); + assert!( + matches!(value_ty, LuaType::Integer | LuaType::IntegerConst(_)), + "expected integer type, got {:?}", + value_ty + ); + } + #[test] fn test_union_member_access_preserves_never() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs index cd1a57133..c5cce4de0 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs @@ -6,10 +6,10 @@ use rowan::NodeOrToken; use smol_str::SmolStr; use crate::{ - InferFailReason, InferGuard, InferGuardRef, LuaFunctionType, LuaGenericType, LuaMemberInfo, - LuaMemberKey, LuaMemberOwner, LuaSemanticDeclId, LuaTupleType, LuaType, LuaTypeNode, - LuaUnionType, SemanticDeclLevel, VariadicType, check_type_compact, infer_node_semantic_decl, - instantiate_type_generic, + GenericTplId, InferFailReason, InferGuard, InferGuardRef, LuaFunctionType, LuaGenericType, + LuaMemberInfo, LuaMemberKey, LuaMemberOwner, LuaSemanticDeclId, LuaTupleType, LuaType, + LuaTypeNode, LuaUnionType, SemanticDeclLevel, VariadicType, check_type_compact, + infer_node_semantic_decl, instantiate_type_generic, semantic::{ generic::TypeMapper, member::{find_index_operations, get_member_map}, @@ -861,10 +861,11 @@ fn param_list_infer_types( LuaType::Variadic(inner) => { let i = i + target_offset; if i >= targets.len() { - if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() { - context.insert_type( - tpl_ref.get_tpl_id(), - InferenceCandidate::ordinary(LuaType::Nil), + if let VariadicType::Base(base) = inner.deref() { + insert_tpl_ref_candidate( + context, + base, + LuaType::Nil, variance, false, priority, @@ -873,8 +874,8 @@ fn param_list_infer_types( break; } - if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() - && let Some(len) = context.inferred_variadic_len(tpl_ref.get_tpl_id()) + if let Some((tpl_id, _)) = variadic_base_tpl_ref(inner.deref()) + && let Some(len) = context.inferred_variadic_len(tpl_id) { target_offset += len - 1; continue; @@ -933,26 +934,24 @@ pub(in crate::semantic::generic) fn return_type_infer_types( match target_variadic.deref() { VariadicType::Base(target_base) => match source_variadic.deref() { VariadicType::Base(source_base) => { - if let LuaType::TplRef(tpl_ref) = source_base { - context.insert_type( - tpl_ref.get_tpl_id(), - InferenceCandidate::ordinary(target_base.clone()), - variance, - false, - priority, - ); - } + insert_tpl_ref_candidate( + context, + source_base, + target_base.clone(), + variance, + false, + priority, + ); } VariadicType::Multi(source_multi) => { for ret_type in source_multi { match ret_type { LuaType::Variadic(inner) => { - if let VariadicType::Base(base) = inner.deref() - && let LuaType::TplRef(tpl_ref) = base - { - context.insert_type( - tpl_ref.get_tpl_id(), - InferenceCandidate::ordinary(target_base.clone()), + if let VariadicType::Base(base) = inner.deref() { + insert_tpl_ref_candidate( + context, + base, + target_base.clone(), variance, false, priority, @@ -960,16 +959,16 @@ pub(in crate::semantic::generic) fn return_type_infer_types( } break; } - LuaType::TplRef(tpl_ref) => { - context.insert_type( - tpl_ref.get_tpl_id(), - InferenceCandidate::ordinary(target_base.clone()), + _ => { + insert_tpl_ref_candidate( + context, + ret_type, + target_base.clone(), variance, false, priority, ); } - _ => {} } } } @@ -1021,14 +1020,56 @@ pub(in crate::semantic::generic) fn return_type_infer_types( Ok(()) } +fn tpl_ref_info(ty: &LuaType) -> Option<(GenericTplId, bool)> { + match ty { + LuaType::TplRef(tpl_ref) => Some((tpl_ref.get_tpl_id(), false)), + LuaType::ConstTplRef(tpl_ref) => Some((tpl_ref.get_tpl_id(), true)), + _ => None, + } +} + +fn variadic_base_tpl_ref(variadic: &VariadicType) -> Option<(GenericTplId, bool)> { + let VariadicType::Base(base) = variadic else { + return None; + }; + tpl_ref_info(base) +} + +fn tpl_ref_candidate(is_const_tpl: bool, ty: LuaType) -> InferenceCandidate { + if is_const_tpl { + InferenceCandidate::const_preserving(ty) + } else { + InferenceCandidate::ordinary(ty) + } +} + +fn insert_tpl_ref_candidate( + context: &mut InferenceContext, + source: &LuaType, + target: LuaType, + variance: InferenceVariance, + top_level: bool, + priority: InferencePriority, +) { + if let Some((tpl_id, is_const_tpl)) = tpl_ref_info(source) { + context.insert_type( + tpl_id, + tpl_ref_candidate(is_const_tpl, target), + variance, + top_level, + priority, + ); + } +} + fn function_varargs_infer_types( context: &mut InferenceContext, variadic: &VariadicType, target_rest_params: &[(String, Option)], ) -> Result<(), InferFailReason> { - if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = variadic { + if let Some((tpl_id, _)) = variadic_base_tpl_ref(variadic) { context.add_variadic_params( - tpl_ref.get_tpl_id(), + tpl_id, target_rest_params .iter() .map(|(name, ty)| (name.clone(), ty.clone())) @@ -1149,19 +1190,19 @@ pub(in crate::semantic::generic) fn variadic_infer_types( } break; } - LuaType::TplRef(tpl_ref) => { + _ => { let Some(target) = target_rest_types.get(i) else { break; }; - context.insert_type( - tpl_ref.get_tpl_id(), - InferenceCandidate::ordinary(target.clone()), + insert_tpl_ref_candidate( + context, + ret_type, + target.clone(), variance, target == original_target, priority, ); } - _ => {} } } } @@ -1311,9 +1352,9 @@ fn tuple_infer_types( return Err(InferFailReason::None); }; if let LuaType::Variadic(inner) = last_type - && let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() + && let Some((tpl_id, _)) = variadic_base_tpl_ref(inner.deref()) { - context.add_variadic_base(tpl_ref.get_tpl_id(), target_array.get_base().clone()); + context.add_variadic_base(tpl_id, target_array.get_base().clone()); } } _ => {} @@ -1502,7 +1543,8 @@ fn try_handle_pairs_metamethod( } .ok_or(InferFailReason::None)?; - let final_return_type = match meta_return { + let iterator_func = meta_return.get_result_slot_type(0).unwrap_or(meta_return); + let final_return_type = match iterator_func { LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), LuaType::Signature(signature_id) => context .db diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs index d9ae9752d..32e77624e 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs @@ -122,6 +122,10 @@ pub(in crate::semantic::generic) fn is_literal_candidate(ty: &LuaType) -> bool { | LuaType::DocBooleanConst(_) | LuaType::TableConst(_) => true, LuaType::Union(union) => union.into_vec().iter().any(is_literal_candidate), + LuaType::MultiLineUnion(union) => union + .get_unions() + .iter() + .any(|(ty, _)| is_literal_candidate(ty)), LuaType::Tuple(tuple) => tuple.get_types().iter().any(is_literal_candidate), LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Base(base) => is_literal_candidate(base), diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs index b826377c0..5b4afa96e 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs @@ -5,11 +5,11 @@ use smol_str::SmolStr; use super::{ InferenceCandidate, InferenceContext, InferencePriority, InferenceResult, InferenceVariance, - context::inference_result_to_mapper_value, + context::inference_result_to_mapper_value, is_literal_candidate, return_type_infer_types, }; use crate::{ - CacheOptions, DbIndex, FileId, GenericTpl, GenericTplId, LuaInferCache, LuaType, LuaTypeDeclId, - TypeOps, + CacheOptions, DbIndex, FileId, GenericTpl, GenericTplId, InferGuard, LuaInferCache, + LuaMultiLineUnion, LuaTupleStatus, LuaTupleType, LuaType, LuaTypeDeclId, TypeOps, VariadicType, semantic::generic::{TypeMapperValue, get_mapped_value}, }; @@ -88,3 +88,71 @@ fn variadic_params_mapper_normalizes_def_types() { TypeMapperValue::Params(vec![("value".to_string(), Some(LuaType::Ref(type_id)))]) ); } + +#[test] +fn literal_candidate_detects_multi_line_union_members() { + let literal_union = LuaType::MultiLineUnion( + LuaMultiLineUnion::new(vec![ + (LuaType::DocStringConst(SmolStr::new("left").into()), None), + (LuaType::Integer, Some("fallback".to_string())), + ]) + .into(), + ); + assert!(is_literal_candidate(&literal_union)); + + let non_literal_union = LuaType::MultiLineUnion( + LuaMultiLineUnion::new(vec![ + (LuaType::String, None), + (LuaType::Integer, Some("fallback".to_string())), + ]) + .into(), + ); + assert!(!is_literal_candidate(&non_literal_union)); +} + +#[test] +fn return_variadic_const_tpl_ref_preserves_structural_base() { + let db = DbIndex::new(); + let mut cache = LuaInferCache::new(FileId::VIRTUAL, CacheOptions::default()); + let mut context = InferenceContext::new(&db, &mut cache, None); + let tpl_id = GenericTplId::Func(0); + let tpl = Arc::new(GenericTpl::new( + tpl_id, + SmolStr::new("T").into(), + None, + None, + )); + context.prepare_inference_slots(HashSet::from([tpl_id])); + + let literal_tuple = LuaType::Tuple( + LuaTupleType::new( + vec![ + LuaType::StringConst(SmolStr::new("mode").into()), + LuaType::IntegerConst(1), + ], + LuaTupleStatus::InferResolve, + ) + .into(), + ); + let source = LuaType::Variadic(VariadicType::Base(LuaType::ConstTplRef(tpl.clone())).into()); + let target = LuaType::Variadic(VariadicType::Base(literal_tuple.clone()).into()); + + return_type_infer_types( + &mut context, + &source, + &target, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + None, + &InferGuard::new(), + ) + .expect("return variadic inference"); + + let return_type = LuaType::ConstTplRef(tpl.clone()); + let mapper = context.fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &mapper).and_then(|value| value.raw_type()), + Some(literal_tuple) + ); +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 04a0d0975..916245e38 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -74,10 +74,7 @@ pub(super) fn instantiate_alias_call( instantiate_select_call(&operands[0], &operands[1]) } - LuaAliasCallKind::Unpack => { - let operands = resolve_unpack_operands(context, frame, operand_exprs); - instantiate_unpack_call(context.db, &operands) - } + LuaAliasCallKind::Unpack => instantiate_unpack_call(context.db, &operands), LuaAliasCallKind::RawGet => { if operands.len() != 2 { return LuaType::Unknown; @@ -240,30 +237,6 @@ fn instantiate_select_call(source: &LuaType, index: &LuaType) -> LuaType { } } -fn resolve_unpack_operands( - context: &GenericInstantiateContext, - frame: GenericInstantiateFrame, - operand_exprs: &[LuaType], -) -> Vec { - operand_exprs - .iter() - .enumerate() - .map(|(index, operand)| { - if index != 0 { - return instantiate_type_generic_inner(context, frame, operand); - } - let raw = match operand { - LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => { - get_mapped_value(tpl_ref.get_tpl_id(), &context.mapper) - .and_then(|value| value.raw_type()) - } - _ => None, - }; - raw.unwrap_or_else(|| instantiate_type_generic_inner(context, frame, operand)) - }) - .collect() -} - fn instantiate_unpack_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { if operands.is_empty() { return LuaType::Unknown; From 96805d074b01931e6bccdfdf2de6e59699688284 Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Sat, 23 May 2026 20:18:40 +0800 Subject: [PATCH 09/10] =?UTF-8?q?refactor(generic):=20`@type`=E7=8E=B0?= =?UTF-8?q?=E5=9C=A8=E4=BC=9A=E8=AE=A1=E7=AE=97=E6=B3=9B=E5=9E=8B=E5=B1=95?= =?UTF-8?q?=E5=BC=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../compilation/analyzer/doc/type_ref_tags.rs | 153 +++++++++++++++++- .../src/compilation/test/generic_test.rs | 79 ++++++++- .../src/db_index/type/humanize_type.rs | 72 ++++++++- .../src/db_index/type/type_decl.rs | 39 ++++- .../generic/instantiate_type/context.rs | 10 +- .../infer_call_func_generic.rs | 13 +- .../semantic/generic/instantiate_type/mod.rs | 15 +- .../src/semantic/generic/test.rs | 97 ++++++++++- .../src/semantic/infer/infer_call/mod.rs | 14 +- .../src/semantic/infer/infer_name.rs | 34 ++-- 10 files changed, 480 insertions(+), 46 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs index de0cc3320..55e91dbe2 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use emmylua_parser::{ LuaAst, LuaAstNode, LuaAstToken, LuaBlock, LuaDocDescriptionOwner, LuaDocTagAs, LuaDocTagCast, LuaDocTagModule, LuaDocTagOther, LuaDocTagOverload, LuaDocTagParam, LuaDocTagReturn, @@ -12,13 +14,14 @@ use super::{ tags::{find_owner_closure, get_owner_id_or_report}, }; use crate::{ - InFiled, JsonSchemaFile, LuaOperatorMetaMethod, LuaTypeCache, LuaTypeOwner, OperatorFunction, - SignatureReturnStatus, TypeOps, + DbIndex, InFiled, JsonSchemaFile, LuaOperatorMetaMethod, LuaTypeCache, LuaTypeDeclId, + LuaTypeOwner, OperatorFunction, SignatureReturnStatus, TplResolvePolicy, TypeMapper, TypeOps, compilation::analyzer::common::bind_type, db_index::{ - LuaDeclId, LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaMemberId, - LuaOperator, LuaSemanticDeclId, LuaSignatureId, LuaType, + LuaDeclId, LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaGenericType, + LuaMemberId, LuaOperator, LuaSemanticDeclId, LuaSignatureId, LuaType, }, + instantiate_type_generic, instantiate_type_generic_full, }; use crate::{ LuaAttributeUse, @@ -36,6 +39,7 @@ pub fn analyze_type(analyzer: &mut DocAnalyzer, tag: LuaDocTagType) -> Option<() let mut type_list = Vec::new(); for lua_doc_type in tag.get_type_list() { let type_ref = infer_type(&mut analyzer.type_context, lua_doc_type); + let type_ref = maybe_instantiate_doc_type(analyzer.get_db(), type_ref); type_list.push(type_ref); } @@ -44,6 +48,147 @@ pub fn analyze_type(analyzer: &mut DocAnalyzer, tag: LuaDocTagType) -> Option<() Some(()) } +fn maybe_instantiate_doc_type(db: &DbIndex, type_ref: LuaType) -> LuaType { + let type_decl_id = match &type_ref { + LuaType::Ref(type_id) => Some(type_id.clone()), + LuaType::Generic(generic) => Some(generic.get_base_type_id()), + _ => None, + }; + let has_alias_chain = type_decl_id + .as_ref() + .and_then(|type_decl_id| db.get_type_index().get_type_decl(type_decl_id)) + .is_some_and(|type_decl| { + type_decl.is_alias() + && matches!( + type_decl.get_alias_ref(), + Some(LuaType::Ref(_) | LuaType::Generic(_)) + ) + }); + let contain_tpl = type_ref.contain_tpl(); + + if !contain_tpl && !has_alias_chain { + return type_ref; + } + + let mapper = TypeMapper::empty(); + + if contain_tpl { + if has_alias_chain { + let (current_type, current_id) = match &type_ref { + LuaType::Generic(generic) => { + let params = generic + .get_params() + .iter() + .map(|param| { + instantiate_type_generic_full( + db, + param, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ) + }) + .collect::>(); + ( + LuaType::Generic( + LuaGenericType::new(generic.get_base_type_id(), params).into(), + ), + generic.get_base_type_id(), + ) + } + LuaType::Ref(type_id) => (LuaType::Ref(type_id.clone()), type_id.clone()), + _ => { + return instantiate_type_generic_full( + db, + &type_ref, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ); + } + }; + + return instantiate_doc_alias_chain(db, current_type, current_id, &mapper); + } + return instantiate_type_generic_full( + db, + &type_ref, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ); + } + + if has_alias_chain { + let instantiated = instantiate_type_generic(db, &type_ref, &mapper); + if !matches!(instantiated, LuaType::Any | LuaType::Unknown) { + return instantiated; + } + } + + type_ref +} + +fn instantiate_doc_alias_chain( + db: &DbIndex, + mut current_type: LuaType, + mut current_id: LuaTypeDeclId, + mapper: &TypeMapper, +) -> LuaType { + let mut visited = HashSet::new(); + loop { + if !visited.insert(current_id.clone()) { + return current_type; + } + + let Some(type_decl) = db.get_type_index().get_type_decl(¤t_id) else { + return current_type; + }; + if !type_decl.is_alias() { + return current_type; + } + + let Some(origin) = type_decl.get_alias_ref() else { + return current_type; + }; + let next_id = match origin { + LuaType::Ref(type_id) => type_id.clone(), + LuaType::Generic(generic) => generic.get_base_type_id(), + _ => return current_type, + }; + + let params = match ¤t_type { + LuaType::Generic(generic) => generic.get_params().clone(), + LuaType::Ref(_) => Vec::new(), + _ => return current_type, + }; + let alias_mapper = TypeMapper::from_alias(db, params, ¤t_id); + let alias_mapper = TypeMapper::merge(Some(alias_mapper), mapper.clone()); + + current_type = match origin { + LuaType::Generic(generic) => { + let params = generic + .get_params() + .iter() + .map(|param| { + instantiate_type_generic_full( + db, + param, + &alias_mapper, + None, + TplResolvePolicy::PreserveTplRef, + ) + }) + .collect::>(); + LuaType::Generic(LuaGenericType::new(next_id.clone(), params).into()) + } + LuaType::Ref(type_id) => LuaType::Ref(type_id.clone()), + _ => return current_type, + }; + current_id = next_id; + } +} + fn bind_type_to_owner( analyzer: &mut DocAnalyzer, tag: &impl LuaAstNode, diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 1ce5adac2..e562bc776 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -139,7 +139,7 @@ mod test { "#, ); let a_ty = ws.expr_ty("a"); - assert_eq!(a_ty, ws.ty("unknown")); + assert_eq!(a_ty, LuaType::Unknown); } // Currently fails: @@ -1055,6 +1055,83 @@ mod test { assert_eq!(ws.humanize_type(value_ty), "Base"); } + #[test] + fn test_doc_type_binding_instantiates_alias_origin() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Box { value: T } + ---@alias Forward Box + + ---@type Forward + local value + + result = value + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("{ value: string }")); + } + + #[test] + fn test_doc_type_binding_preserves_residual_aliases() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Extract T extends U and T or never + ---@alias KeepA Extract + ---@alias ForwardConditional KeepA + + ---@alias Copy { [K in keyof T]: T[K] } + ---@alias ForwardMapped Copy + + ---@generic T + local function f() + ---@type ForwardConditional + local cond + + ---@type ForwardMapped + local mapped + + cond_result = cond + mapped_result = mapped + end + "#, + ); + + let cond_ty = ws.expr_ty("cond_result"); + let cond_desc = ws.humanize_type(cond_ty); + assert_eq!(cond_desc, r#"Extract"#); + + let mapped_ty = ws.expr_ty("mapped_result"); + assert_eq!(ws.humanize_type(mapped_ty), "Copy"); + } + + #[test] + fn test_detailed_humanize_shows_mapped_alias_body() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Copy { [K in keyof T]: T[K] } + ---@alias ForwardMapped Copy + + ---@generic T + local function f() + ---@type ForwardMapped + local mapped + + mapped_result = mapped + end + "#, + ); + + let mapped_ty = ws.expr_ty("mapped_result"); + assert_eq!(ws.humanize_type(mapped_ty.clone()), "Copy"); + let mapped_desc = ws.humanize_type_detailed(mapped_ty); + assert!(mapped_desc.starts_with("Copy = { [K in keyof T]: ")); + } + #[test] fn test_partial_generic_type_fills_trailing_default() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs index 05b19bed3..4bcd26cef 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs @@ -5,8 +5,9 @@ use itertools::Itertools; use crate::{ AsyncState, DbIndex, LuaAliasCallType, LuaConditionalType, LuaFunctionType, LuaGenericType, - LuaIntersectionType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaSignatureId, - LuaStringTplType, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, TypeMapper, VariadicType, + LuaIntersectionType, LuaMappedType, LuaMemberKey, LuaMemberOwner, LuaObjectType, + LuaSignatureId, LuaStringTplType, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, + TypeMapper, VariadicType, }; use super::{LuaAliasCallKind, LuaMultiLineUnion}; @@ -206,6 +207,7 @@ impl<'a> TypeHumanizer<'a> { LuaType::MultiLineUnion(multi_union) => { self.write_multi_line_union_type(multi_union, w) } + LuaType::Mapped(mapped) => self.write_mapped_type(mapped, w), LuaType::TypeGuard(inner) => { w.write_str("TypeGuard<")?; let saved = self.level; @@ -717,7 +719,9 @@ impl<'a> TypeHumanizer<'a> { } let mapper = TypeMapper::from_type_array(generic.get_params().clone()); - if let Some(origin_type) = type_decl.get_alias_origin(self.db, Some(&mapper)) { + if let Some(origin_type) = + type_decl.get_alias_origin_preserve_tpl(self.db, Some(&mapper)) + { w.write_str(" = ")?; let saved = self.level; self.level = self.child_level(); @@ -1024,6 +1028,68 @@ impl<'a> TypeHumanizer<'a> { } } + // ─── Mapped ───────────────────────────────────────────────────── + + fn write_mapped_type(&mut self, mapped: &LuaMappedType, w: &mut W) -> fmt::Result { + if self.level == RenderLevel::Minimal { + return w.write_str("{...}"); + } + + w.write_str("{ ")?; + if mapped.is_readonly { + w.write_str("readonly ")?; + } + + let saved = self.level; + self.level = self.child_level(); + + w.write_char('[')?; + w.write_str(mapped.param.1.name.as_str())?; + w.write_str(" in ")?; + self.write_mapped_constraint(mapped.param.1.type_constraint.as_ref(), w)?; + w.write_char(']')?; + if mapped.is_optional { + w.write_char('?')?; + } + w.write_str(": ")?; + self.write_mapped_value(&mapped.value, w)?; + + self.level = saved; + w.write_str(" }") + } + + fn write_mapped_constraint( + &mut self, + constraint: Option<&LuaType>, + w: &mut W, + ) -> fmt::Result { + match constraint { + Some(LuaType::Call(alias_call)) + if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 => + { + w.write_str("keyof ")?; + self.write_type(&alias_call.get_operands()[0], w) + } + Some(ty) => self.write_type(ty, w), + None => w.write_str("unknown"), + } + } + + fn write_mapped_value(&mut self, value: &LuaType, w: &mut W) -> fmt::Result { + if let LuaType::Call(alias_call) = value + && alias_call.get_call_kind() == LuaAliasCallKind::Index + && alias_call.get_operands().len() == 2 + { + self.write_type(&alias_call.get_operands()[0], w)?; + w.write_char('[')?; + self.write_type(&alias_call.get_operands()[1], w)?; + return w.write_char(']'); + } + + self.write_type(value, w) + } + // ─── helper: write a table member (key: type) ─────────────────── fn write_table_member_field( diff --git a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs index 7ca509d87..8b655303d 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs @@ -7,8 +7,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use smol_str::SmolStr; use crate::{ - DbIndex, FileId, LuaMemberKey, LuaMemberOwner, TypeMapper, db_index::WorkspaceId, - instantiate_type_generic, + DbIndex, FileId, LuaMemberKey, LuaMemberOwner, TplResolvePolicy, TypeMapper, + db_index::WorkspaceId, instantiate_type_generic, instantiate_type_generic_full, }; use super::{LuaType, LuaUnionType}; @@ -155,6 +155,41 @@ impl LuaTypeDecl { } } + pub fn get_alias_origin_preserve_tpl( + &self, + db: &DbIndex, + mapper: Option<&TypeMapper>, + ) -> Option { + match &self.extra { + LuaTypeExtra::Alias { + origin: Some(origin), + } => { + let mapper = match mapper { + Some(mapper) => mapper, + None => return Some(origin.clone()), + }; + + let type_decl_id = self.get_id(); + if db + .get_type_index() + .get_generic_params(&type_decl_id) + .is_none() + { + return Some(origin.clone()); + } + + Some(instantiate_type_generic_full( + db, + origin, + mapper, + None, + TplResolvePolicy::PreserveTplRef, + )) + } + _ => None, + } + } + pub fn get_alias_ref(&self) -> Option<&LuaType> { match &self.extra { LuaTypeExtra::Alias { origin, .. } => origin.as_ref(), diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs index 316a19f5e..ee3f81c35 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs @@ -8,7 +8,7 @@ const MAX_INSTANTIATION_DEPTH: usize = 128; const MAX_ALIAS_STACK: usize = 32; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(in crate::semantic::generic) enum UninferredTplPolicy { +pub enum TplResolvePolicy { /// 未推断模板按 `default -> constraint -> unknown` 推断成实际类型. Fallback, /// 没有默认值的未推断模板仍保留为 `TplRef`, 让后续调用点继续参与参数推导. @@ -25,7 +25,7 @@ pub(in crate::semantic::generic) struct GenericInstantiateContext<'a> { #[derive(Debug, Clone, Copy)] pub(in crate::semantic::generic) struct GenericInstantiateFrame { - policy: UninferredTplPolicy, + policy: TplResolvePolicy, depth: usize, } @@ -45,7 +45,7 @@ impl<'a> GenericInstantiateContext<'a> { pub(super) fn root_frame(&self) -> GenericInstantiateFrame { GenericInstantiateFrame { - policy: UninferredTplPolicy::Fallback, + policy: TplResolvePolicy::Fallback, depth: 0, } } @@ -94,12 +94,12 @@ impl<'a> GenericInstantiateContext<'a> { } impl GenericInstantiateFrame { - pub(super) fn with_policy(self, policy: UninferredTplPolicy) -> Self { + pub(super) fn with_policy(self, policy: TplResolvePolicy) -> Self { Self { policy, ..self } } pub(super) fn should_preserve_tpl_ref(&self) -> bool { - self.policy == UninferredTplPolicy::PreserveTplRef + self.policy == TplResolvePolicy::PreserveTplRef } pub(super) fn enter(self) -> Option { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs index 60f188d3f..fa7c8157b 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs @@ -29,7 +29,8 @@ use crate::{ }; use super::{ - TypeMapper, TypeMapperValue, instantiate_type_generic, instantiate_type_generic_with_self, + TplResolvePolicy, TypeMapper, TypeMapperValue, instantiate_type_generic, + instantiate_type_generic_full, }; pub fn infer_call_func_generic( @@ -81,9 +82,13 @@ pub fn infer_call_func_generic( .then(|| infer_self_type(db, cache, &call_expr)) .flatten(); let func_ty = LuaType::DocFunction(func.clone().into()); - if let LuaType::DocFunction(f) = - instantiate_type_generic_with_self(db, &func_ty, &mapper, self_type.as_ref()) - { + if let LuaType::DocFunction(f) = instantiate_type_generic_full( + db, + &func_ty, + &mapper, + self_type.as_ref(), + TplResolvePolicy::Fallback, + ) { Ok(f.deref().clone()) } else { Ok(func.clone()) diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index b343f0441..ec1c1c430 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -22,7 +22,8 @@ use super::{TypeMapper, TypeMapperValue, get_mapped_value}; pub use complete_generic_args::{ GenericArgumentCompletion, complete_type_generic_args, complete_type_generic_args_in_type, }; -use context::{GenericInstantiateContext, GenericInstantiateFrame, UninferredTplPolicy}; +pub use context::TplResolvePolicy; +use context::{GenericInstantiateContext, GenericInstantiateFrame}; pub use infer_call_func_generic::{build_self_type, infer_call_func_generic, infer_self_type}; pub(in crate::semantic::generic) use inference_widening::{ is_primitive_or_literal_type, regularize_tpl_candidate_type, widen_tpl_candidate_type, @@ -31,17 +32,18 @@ use instantiate_mapped_type::instantiate_mapped_type as instantiate_mapped_type_ pub use instantiate_special_generic::get_keyof_members; pub fn instantiate_type_generic(db: &DbIndex, ty: &LuaType, mapper: &TypeMapper) -> LuaType { - instantiate_type_generic_with_self(db, ty, mapper, None) + instantiate_type_generic_full(db, ty, mapper, None, TplResolvePolicy::Fallback) } -pub fn instantiate_type_generic_with_self( +pub fn instantiate_type_generic_full( db: &DbIndex, ty: &LuaType, mapper: &TypeMapper, self_type: Option<&LuaType>, + root_policy: TplResolvePolicy, ) -> LuaType { let context = GenericInstantiateContext::new(db, mapper, self_type); - let frame = context.root_frame(); + let frame = context.root_frame().with_policy(root_policy); match ty { LuaType::DocFunction(doc_func) => instantiate_doc_function(&context, frame, doc_func), _ => instantiate_type_generic_inner(&context, frame, ty), @@ -80,7 +82,7 @@ pub(super) fn instantiate_type_generic_inner( } instantiate_doc_function( context, - frame.with_policy(UninferredTplPolicy::PreserveTplRef), + frame.with_policy(TplResolvePolicy::PreserveTplRef), doc_func, ) } @@ -555,6 +557,9 @@ fn instantiate_const_tpl_ref( if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { TypeMapperValue::None => { + if frame.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { + return LuaType::ConstTplRef(tpl.clone().into()); + } return instantiate_uninferred_tpl_fallback(tpl, context, frame); } TypeMapperValue::Type(value) => { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/test.rs b/crates/emmylua_code_analysis/src/semantic/generic/test.rs index 0581483fc..967ab7103 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/test.rs @@ -3,7 +3,10 @@ mod test { use hashbrown::HashMap; use std::sync::Arc; - use super::super::instantiate_type::{regularize_tpl_candidate_type, widen_tpl_candidate_type}; + use super::super::instantiate_type::{ + TplResolvePolicy, instantiate_type_generic_full, regularize_tpl_candidate_type, + widen_tpl_candidate_type, + }; use crate::{ AsyncState, DbIndex, DiagnosticCode, GenericTpl, GenericTplId, LuaArrayType, LuaFunctionType, LuaIntersectionType, LuaMemberKey, LuaObjectType, LuaTupleStatus, @@ -12,6 +15,15 @@ mod test { }; use smol_str::SmolStr; + fn func_tpl(idx: u32, name: &str) -> Arc { + Arc::new(GenericTpl::new( + GenericTplId::Func(idx), + SmolStr::new(name).into(), + None, + None, + )) + } + #[test] fn test_variadic_func() { let mut ws = VirtualWorkspace::new(); @@ -675,6 +687,89 @@ result = { ); } + #[test] + fn test_preserve_uninferred_keeps_unmapped_tpl_ref_and_const_tpl_ref() { + let db = DbIndex::new(); + let mapper = TypeMapper::from_uninferred(vec![GenericTplId::Func(0)]); + let tpl = func_tpl(0, "T0"); + + let tpl_ref = LuaType::TplRef(tpl.clone()); + assert_eq!( + instantiate_type_generic_full( + &db, + &tpl_ref, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ), + tpl_ref + ); + + let const_tpl_ref = LuaType::ConstTplRef(tpl.clone()); + assert_eq!( + instantiate_type_generic_full( + &db, + &const_tpl_ref, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ), + const_tpl_ref + ); + } + + #[test] + fn test_preserve_uninferred_applies_concrete_values_while_retaining_residual_templates() { + let db = DbIndex::new(); + let concrete = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let unresolved = + TypeMapper::from_uninferred(vec![GenericTplId::Func(1), GenericTplId::Func(2)]); + let mapper = TypeMapper::merge(Some(concrete), unresolved); + let t0 = func_tpl(0, "T0"); + let t1 = func_tpl(1, "T1"); + let t2 = func_tpl(2, "T2"); + + // 这段输入可以理解成 Lua 伪代码: + // ---@type [T0, T1, const T2] + // local value + // 套用 mapper 后,T0 会被具体化为 string, + // 但 T1 / T2 仍然要保留为残余模板,不能被过早折叠掉。 + let ty = LuaType::Tuple( + LuaTupleType::new( + vec![ + LuaType::TplRef(t0), + LuaType::TplRef(t1.clone()), + LuaType::ConstTplRef(t2.clone()), + ], + LuaTupleStatus::DocResolve, + ) + .into(), + ); + let preserved = instantiate_type_generic_full( + &db, + &ty, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ); + let expected = LuaType::Tuple( + LuaTupleType::new( + vec![ + LuaType::String, + LuaType::TplRef(t1), + LuaType::ConstTplRef(t2), + ], + LuaTupleStatus::DocResolve, + ) + .into(), + ); + + assert_eq!(preserved, expected); + } + #[test] fn test_123() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index b24e659cc..4d000ba9c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -22,8 +22,8 @@ use crate::{ }, }; use crate::{ - build_self_type, infer_call_func_generic, infer_self_type, instantiate_type_generic_with_self, - semantic::infer_expr, + TplResolvePolicy, build_self_type, infer_call_func_generic, infer_self_type, + instantiate_type_generic_full, semantic::infer_expr, }; use infer_require::infer_require_call; use infer_setmetatable::infer_setmetatable_call; @@ -367,9 +367,13 @@ fn infer_type_doc_function( let mapper = TypeMapper::empty(); let self_type = build_self_type(db, call_expr_type); let func_ty = LuaType::DocFunction(f.clone()); - if let LuaType::DocFunction(f) = - instantiate_type_generic_with_self(db, &func_ty, &mapper, Some(&self_type)) - { + if let LuaType::DocFunction(f) = instantiate_type_generic_full( + db, + &func_ty, + &mapper, + Some(&self_type), + TplResolvePolicy::Fallback, + ) { overloads.push(f); } } else { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs index 357b9ae8f..9e60638fa 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs @@ -351,18 +351,16 @@ pub fn infer_global_type(db: &DbIndex, name: &str) -> InferResult { .ok_or(InferFailReason::None)?; if decl_ids.len() == 1 { let id = decl_ids[0]; - let typ = match db.get_type_index().get_type_cache(&id.into()) { - Some(type_cache) => type_cache.as_type().clone(), - None => return Err(InferFailReason::UnResolveDeclType(id)), - }; - // todo: 不置为 Unknown 有可能引用泛型函数中的泛型参数导致泄露, 但这样会导致丢失类型, 我们可能需要更好的办法去处理 - return if !typ.is_generic() && typ.contain_tpl() { - // This decl is located in a generic function, - // and is type contains references to generic variables - // of this function. - Ok(LuaType::Unknown) - } else { - Ok(typ) + return match db.get_type_index().get_type_cache(&id.into()) { + Some(type_cache) => { + let typ = type_cache.as_type(); + if is_bare_leaked_tpl(typ) { + Ok(LuaType::Unknown) + } else { + Ok(typ.clone()) + } + } + None => Err(InferFailReason::UnResolveDeclType(id)), }; } @@ -382,10 +380,7 @@ pub fn infer_global_type(db: &DbIndex, name: &str) -> InferResult { Some(type_cache) => { let typ = type_cache.as_type(); - if typ.contain_tpl() { - // This decl is located in a generic function, - // and is type contains references to generic variables - // of this function. + if is_bare_leaked_tpl(typ) { continue; } @@ -461,3 +456,10 @@ pub fn find_self_decl_or_member_id( _ => None, } } + +fn is_bare_leaked_tpl(typ: &LuaType) -> bool { + matches!( + typ, + LuaType::TplRef(_) | LuaType::ConstTplRef(_) | LuaType::StrTplRef(_) | LuaType::SelfInfer + ) +} From 527f7d72e0d213f8fa5850f5758d6a161b1921b9 Mon Sep 17 00:00:00 2001 From: xuhuanzy <501417909@qq.com> Date: Sat, 23 May 2026 22:29:37 +0800 Subject: [PATCH 10/10] update generic --- .../test/for_range_var_infer_test.rs | 27 +++++++++++++ .../src/db_index/type/types/test.rs | 9 +++++ .../semantic/generic/inference/infer_types.rs | 16 +++++--- .../src/semantic/generic/inference/tests.rs | 40 ++++++++++++++++++- 4 files changed, 85 insertions(+), 7 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs index e4291933e..f75e18d6f 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs @@ -272,6 +272,33 @@ mod test { assert_eq!(ws.expr_ty("value_out"), LuaType::Number); } + #[test] + fn test_pairs_metamethod_extracts_iterator_from_single_return() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class PairSource + local PairSource + + ---@return fun(): string + ---@return table + ---@return nil + function PairSource:__pairs() + end + + ---@type PairSource + local source + + for k, v in pairs(source) do + key_out = k + end + "#, + ); + + assert_eq!(ws.expr_ty("key_out"), LuaType::String); + } + #[test] fn test_issue_291() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/test.rs b/crates/emmylua_code_analysis/src/db_index/type/types/test.rs index ec63747de..69c5297ad 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/test.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/test.rs @@ -20,6 +20,15 @@ mod tests { ); } + #[test] + fn test_single_return_uses_result_slot_extraction() { + assert_eq!( + LuaType::String.get_result_slot_type(0), + Some(LuaType::String) + ); + assert_eq!(LuaType::String.get_result_slot_type(1), None); + } + #[test] fn test_deep_contain_tpl_uses_iterative_walk() { let mut ty = LuaType::TplRef( diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs index c5cce4de0..4cadc27b2 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs @@ -100,7 +100,7 @@ fn infer_types_inner( } } LuaType::StrTplRef(str_tpl) => { - if let LuaType::StringConst(s) = target { + if let LuaType::StringConst(s) | LuaType::DocStringConst(s) = target { let type_name = SmolStr::new(format!( "{}{}{}", str_tpl.get_prefix(), @@ -1554,13 +1554,17 @@ fn try_handle_pairs_metamethod( _ => None, }; - if let Some(LuaType::Variadic(variadic)) = &final_return_type { - let key_type = variadic.get_type(0).ok_or(InferFailReason::None)?; - let value_type = variadic.get_type(1).ok_or(InferFailReason::None)?; + if let Some(final_return_type) = &final_return_type { + let key_type = final_return_type + .get_result_slot_type(0) + .ok_or(InferFailReason::None)?; + let value_type = final_return_type + .get_result_slot_type(1) + .unwrap_or(LuaType::Nil); infer_types_inner( context, &table_params[0], - key_type, + &key_type, original_target, variance, priority, @@ -1570,7 +1574,7 @@ fn try_handle_pairs_metamethod( infer_types_inner( context, &table_params[1], - value_type, + &value_type, original_target, variance, priority, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs index 5b4afa96e..3024f4794 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs @@ -3,13 +3,15 @@ use std::sync::Arc; use hashbrown::HashSet; use smol_str::SmolStr; +use super::infer_types::infer_types; use super::{ InferenceCandidate, InferenceContext, InferencePriority, InferenceResult, InferenceVariance, context::inference_result_to_mapper_value, is_literal_candidate, return_type_infer_types, }; use crate::{ CacheOptions, DbIndex, FileId, GenericTpl, GenericTplId, InferGuard, LuaInferCache, - LuaMultiLineUnion, LuaTupleStatus, LuaTupleType, LuaType, LuaTypeDeclId, TypeOps, VariadicType, + LuaMultiLineUnion, LuaStringTplType, LuaTupleStatus, LuaTupleType, LuaType, LuaTypeDeclId, + TypeOps, VariadicType, semantic::generic::{TypeMapperValue, get_mapped_value}, }; @@ -110,6 +112,42 @@ fn literal_candidate_detects_multi_line_union_members() { assert!(!is_literal_candidate(&non_literal_union)); } +#[test] +fn str_tpl_inference_accepts_doc_string_const() { + let db = DbIndex::new(); + let mut cache = LuaInferCache::new(FileId::VIRTUAL, CacheOptions::default()); + let mut context = InferenceContext::new(&db, &mut cache, None); + let tpl_id = GenericTplId::Func(0); + let tpl = Arc::new(GenericTpl::new( + tpl_id, + SmolStr::new("T").into(), + None, + None, + )); + context.prepare_inference_slots(HashSet::from([tpl_id])); + + let source = + LuaType::StrTplRef(LuaStringTplType::new("aaa.", "T", tpl_id, ".bbb", None).into()); + let target = LuaType::DocStringConst(SmolStr::new("xxx").into()); + + infer_types( + &mut context, + &source, + &target, + &target, + InferenceVariance::Covariant, + InferencePriority::Normal, + ) + .expect("string template inference"); + + let return_type = LuaType::TplRef(tpl.clone()); + let mapper = context.fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &mapper).and_then(|value| value.raw_type()), + Some(LuaType::Ref(LuaTypeDeclId::global("aaa.xxx.bbb"))) + ); +} + #[test] fn return_variadic_const_tpl_ref_preserves_structural_base() { let db = DbIndex::new();