diff --git a/crates/emmylua_code_analysis/resources/std/builtin.lua b/crates/emmylua_code_analysis/resources/std/builtin.lua index da7557002..60f6d9c49 100644 --- a/crates/emmylua_code_analysis/resources/std/builtin.lua +++ b/crates/emmylua_code_analysis/resources/std/builtin.lua @@ -129,7 +129,7 @@ --- @alias std.RawGet unknown --- ---- built-in type for generic template, for match integer const and true/false +--- built-in type for generic template, used for const generic --- @alias std.ConstTpl unknown --- compact luals @@ -167,6 +167,13 @@ --- Extract from T those types that are assignable to U --- @alias Extract T extends U and T or never +--- +--- From T, pick a set of properties whose keys are in the union K +--- @alias Pick {[P in K]: T[P]; } + +--- +--- Construct a type with the properties of T except for those in type K. +--- @alias Omit Pick> --- attribute diff --git a/crates/emmylua_code_analysis/resources/std/global.lua b/crates/emmylua_code_analysis/resources/std/global.lua index 220ca2de9..eadf553c0 100644 --- a/crates/emmylua_code_analysis/resources/std/global.lua +++ b/crates/emmylua_code_analysis/resources/std/global.lua @@ -463,7 +463,7 @@ function xpcall(f, msgh, ...) end --- @generic T, Start: integer, End: integer --- @param i? std.ConstTpl --- @param j? std.ConstTpl ---- @param list T +--- @param list std.ConstTpl --- @return std.Unpack function unpack(list, i, j) end diff --git a/crates/emmylua_code_analysis/resources/std/table.lua b/crates/emmylua_code_analysis/resources/std/table.lua index ee60a7343..171d858e6 100644 --- a/crates/emmylua_code_analysis/resources/std/table.lua +++ b/crates/emmylua_code_analysis/resources/std/table.lua @@ -109,7 +109,7 @@ function table.sort(list, comp) end --- @generic T, Start: integer, End: integer --- @param i? std.ConstTpl --- @param j? std.ConstTpl ---- @param list T +--- @param list std.ConstTpl --- @return std.Unpack function table.unpack(list, i, j) end diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs index e3ad351d2..5f26acc62 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs @@ -9,12 +9,24 @@ use crate::{GenericParam, GenericTpl, GenericTplId, LuaType}; pub trait GenericIndex: std::fmt::Debug { fn add_generic_scope(&mut self, ranges: Vec, is_func: bool) -> GenericScopeId; - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam); - - fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option; + + fn append_generic_params( + &mut self, + scope_id: GenericScopeId, + params: Vec, + ) -> Vec { + let mut appended = Vec::new(); for param in params { - self.append_generic_param(scope_id, param); + if let Some(tpl_id) = self.append_generic_param(scope_id, param.clone()) { + appended.push(param.with_tpl_id(Some(tpl_id))); + } } + appended } fn find_generic( @@ -63,16 +75,15 @@ impl GenericIndex for FileGenericIndex { scope_id } - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam) { + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option { if let Some(scope) = self.scopes.get_mut(scope_id.id) { - scope.insert_param(param); - } - } - - fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { - for param in params { - self.append_generic_param(scope_id, param); + return Some(scope.insert_param(param)); } + None } /// Find generic parameter by position and name. @@ -131,10 +142,12 @@ impl FileGenericScope { self.next_tpl_id.is_func() } - fn insert_param(&mut self, param: GenericParam) { - let tpl_id = self.next_tpl_id; - self.next_tpl_id = self.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); + fn insert_param(&mut self, param: GenericParam) -> GenericTplId { + let tpl_id = param.tpl_id.unwrap_or(self.next_tpl_id); + let next_idx = self.next_tpl_id.get_idx().max(tpl_id.get_idx() + 1) as u32; + self.next_tpl_id = self.next_tpl_id.with_idx(next_idx); self.params.insert(param.name.to_string(), (tpl_id, param)); + tpl_id } fn contains(&self, position: TextSize) -> bool { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs index d63ed8290..0fc9ad2b9 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs @@ -142,9 +142,10 @@ pub fn analyze_alias(analyzer: &mut DocAnalyzer, tag: LuaDocTagAlias) -> Option< alias_decl.get_id() }; + let type_node = tag.get_type()?; if tag.get_generic_decl_list().is_some() { let generic_params = get_type_generic_params(analyzer, &alias_decl_id); - let range = analyzer.comment.get_range(); + let range = type_node.get_range(); let scope_id = analyzer .type_context .generic_index @@ -155,7 +156,7 @@ pub fn analyze_alias(analyzer: &mut DocAnalyzer, tag: LuaDocTagAlias) -> Option< .append_generic_params(scope_id, generic_params); } - let mut origin_type = infer_type(&mut analyzer.type_context, tag.get_type()?); + let mut origin_type = infer_type(&mut analyzer.type_context, type_node); if alias_origin_reaches(analyzer.get_db(), &origin_type, &alias_decl_id) { origin_type = LuaType::Any; } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs index ed8601db4..dcd26943f 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs @@ -68,6 +68,7 @@ fn normalize_generic_params(db: &DbIndex, params: &[GenericParam]) -> Vec Option { + let scope = self.scopes.get_mut(scope_id.id)?; let tpl_id = scope.next_tpl_id; scope.next_tpl_id = scope.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); scope.params.push((tpl_id, param)); + Some(tpl_id) } fn find_generic( diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs index de0cc3320..55e91dbe2 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use emmylua_parser::{ LuaAst, LuaAstNode, LuaAstToken, LuaBlock, LuaDocDescriptionOwner, LuaDocTagAs, LuaDocTagCast, LuaDocTagModule, LuaDocTagOther, LuaDocTagOverload, LuaDocTagParam, LuaDocTagReturn, @@ -12,13 +14,14 @@ use super::{ tags::{find_owner_closure, get_owner_id_or_report}, }; use crate::{ - InFiled, JsonSchemaFile, LuaOperatorMetaMethod, LuaTypeCache, LuaTypeOwner, OperatorFunction, - SignatureReturnStatus, TypeOps, + DbIndex, InFiled, JsonSchemaFile, LuaOperatorMetaMethod, LuaTypeCache, LuaTypeDeclId, + LuaTypeOwner, OperatorFunction, SignatureReturnStatus, TplResolvePolicy, TypeMapper, TypeOps, compilation::analyzer::common::bind_type, db_index::{ - LuaDeclId, LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaMemberId, - LuaOperator, LuaSemanticDeclId, LuaSignatureId, LuaType, + LuaDeclId, LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaGenericType, + LuaMemberId, LuaOperator, LuaSemanticDeclId, LuaSignatureId, LuaType, }, + instantiate_type_generic, instantiate_type_generic_full, }; use crate::{ LuaAttributeUse, @@ -36,6 +39,7 @@ pub fn analyze_type(analyzer: &mut DocAnalyzer, tag: LuaDocTagType) -> Option<() let mut type_list = Vec::new(); for lua_doc_type in tag.get_type_list() { let type_ref = infer_type(&mut analyzer.type_context, lua_doc_type); + let type_ref = maybe_instantiate_doc_type(analyzer.get_db(), type_ref); type_list.push(type_ref); } @@ -44,6 +48,147 @@ pub fn analyze_type(analyzer: &mut DocAnalyzer, tag: LuaDocTagType) -> Option<() Some(()) } +fn maybe_instantiate_doc_type(db: &DbIndex, type_ref: LuaType) -> LuaType { + let type_decl_id = match &type_ref { + LuaType::Ref(type_id) => Some(type_id.clone()), + LuaType::Generic(generic) => Some(generic.get_base_type_id()), + _ => None, + }; + let has_alias_chain = type_decl_id + .as_ref() + .and_then(|type_decl_id| db.get_type_index().get_type_decl(type_decl_id)) + .is_some_and(|type_decl| { + type_decl.is_alias() + && matches!( + type_decl.get_alias_ref(), + Some(LuaType::Ref(_) | LuaType::Generic(_)) + ) + }); + let contain_tpl = type_ref.contain_tpl(); + + if !contain_tpl && !has_alias_chain { + return type_ref; + } + + let mapper = TypeMapper::empty(); + + if contain_tpl { + if has_alias_chain { + let (current_type, current_id) = match &type_ref { + LuaType::Generic(generic) => { + let params = generic + .get_params() + .iter() + .map(|param| { + instantiate_type_generic_full( + db, + param, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ) + }) + .collect::>(); + ( + LuaType::Generic( + LuaGenericType::new(generic.get_base_type_id(), params).into(), + ), + generic.get_base_type_id(), + ) + } + LuaType::Ref(type_id) => (LuaType::Ref(type_id.clone()), type_id.clone()), + _ => { + return instantiate_type_generic_full( + db, + &type_ref, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ); + } + }; + + return instantiate_doc_alias_chain(db, current_type, current_id, &mapper); + } + return instantiate_type_generic_full( + db, + &type_ref, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ); + } + + if has_alias_chain { + let instantiated = instantiate_type_generic(db, &type_ref, &mapper); + if !matches!(instantiated, LuaType::Any | LuaType::Unknown) { + return instantiated; + } + } + + type_ref +} + +fn instantiate_doc_alias_chain( + db: &DbIndex, + mut current_type: LuaType, + mut current_id: LuaTypeDeclId, + mapper: &TypeMapper, +) -> LuaType { + let mut visited = HashSet::new(); + loop { + if !visited.insert(current_id.clone()) { + return current_type; + } + + let Some(type_decl) = db.get_type_index().get_type_decl(¤t_id) else { + return current_type; + }; + if !type_decl.is_alias() { + return current_type; + } + + let Some(origin) = type_decl.get_alias_ref() else { + return current_type; + }; + let next_id = match origin { + LuaType::Ref(type_id) => type_id.clone(), + LuaType::Generic(generic) => generic.get_base_type_id(), + _ => return current_type, + }; + + let params = match ¤t_type { + LuaType::Generic(generic) => generic.get_params().clone(), + LuaType::Ref(_) => Vec::new(), + _ => return current_type, + }; + let alias_mapper = TypeMapper::from_alias(db, params, ¤t_id); + let alias_mapper = TypeMapper::merge(Some(alias_mapper), mapper.clone()); + + current_type = match origin { + LuaType::Generic(generic) => { + let params = generic + .get_params() + .iter() + .map(|param| { + instantiate_type_generic_full( + db, + param, + &alias_mapper, + None, + TplResolvePolicy::PreserveTplRef, + ) + }) + .collect::>(); + LuaType::Generic(LuaGenericType::new(next_id.clone(), params).into()) + } + LuaType::Ref(type_id) => LuaType::Ref(type_id.clone()), + _ => return current_type, + }; + current_id = next_id; + } +} + fn bind_type_to_owner( analyzer: &mut DocAnalyzer, tag: &impl LuaAstNode, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs index d6605830b..0d829df86 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs @@ -2,9 +2,8 @@ use emmylua_parser::{LuaAstToken, LuaExpr, LuaForRangeStat}; use crate::{ DbIndex, InferFailReason, LuaDeclId, LuaInferCache, LuaOperatorMetaMethod, LuaType, - LuaTypeCache, TplContext, TypeOps, TypeSubstitutor, VariadicType, - compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_doc_function, - tpl_pattern_match_args, + LuaTypeCache, TypeOps, VariadicType, compilation::analyzer::unresolve::UnResolveIterVar, + infer_expr, instantiate_doc_function_by_arg_types, }; use super::LuaAnalyzer; @@ -144,28 +143,8 @@ pub fn infer_for_range_iter_expr_func( let Some(status_param) = status_param else { return Ok(doc_function.get_variadic_ret()); }; - let mut substitutor = TypeSubstitutor::new(); - let mut context = TplContext { - db, - cache, - substitutor: &mut substitutor, - call_expr: None, - }; - let params = doc_function - .get_params() - .iter() - .map(|(_, opt_ty)| opt_ty.clone().unwrap_or(LuaType::Any)) - .collect::>(); - - tpl_pattern_match_args(&mut context, ¶ms, &[status_param])?; - - let instantiate_func = if let LuaType::DocFunction(f) = - instantiate_doc_function(db, &doc_function, &substitutor) - { - f - } else { - doc_function - }; + let instantiate_func = + instantiate_doc_function_by_arg_types(db, cache, &doc_function, &[status_param])?; Ok(instantiate_func.get_variadic_ret()) } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index 34aaf8186..b055e22b6 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -1,8 +1,9 @@ use emmylua_parser::{ BinaryOperator, LuaAssignStat, LuaAstNode, LuaExpr, LuaFuncStat, LuaIndexExpr, LuaIndexKey, - LuaLocalFuncStat, LuaLocalStat, LuaNameExpr, LuaTableExpr, LuaTableField, LuaVarExpr, - PathTrait, + LuaLocalFuncStat, LuaLocalStat, LuaNameExpr, LuaSyntaxId, LuaTableExpr, LuaTableField, + LuaVarExpr, PathTrait, }; +use hashbrown::HashSet; use crate::{ InFiled, InferFailReason, LuaBuiltinAttributeKind, LuaLspOptimizationCode, LuaMemberKey, @@ -47,11 +48,13 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) break; }; + pre_analyze_call_arg_table_fields(analyzer, &expr); + match analyzer.infer_expr(&expr) { Ok(expr_type) => { let expr_type = expr_type.get_result_slot_type(0).unwrap_or(expr_type); let decl_id = LuaDeclId::new(analyzer.file_id, position); - // 当`call`参数包含表时, 表可能未被分析, 需要延迟 + // 当表达式中存在带表参数的调用时, 表可能尚未完成预分析, 需要延迟 if let LuaType::Instance(instance) = &expr_type && instance.get_base().is_unknown() && call_expr_has_effect_table_arg(&expr).is_some() @@ -162,17 +165,7 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) } fn call_expr_has_effect_table_arg(expr: &LuaExpr) -> Option<()> { - if let LuaExpr::CallExpr(call_expr) = expr { - let args_list = call_expr.get_args_list()?; - for arg in args_list.get_args() { - if let LuaExpr::TableExpr(table_expr) = arg - && !table_expr.is_empty() - { - return Some(()); - } - } - } - None + expr_has_effect_table_call_arg(expr.clone()) } fn get_var_owner(analyzer: &mut LuaAnalyzer, var: LuaVarExpr) -> LuaTypeOwner { @@ -316,6 +309,8 @@ pub fn analyze_assign_stat(analyzer: &mut LuaAnalyzer, assign_stat: LuaAssignSta continue; } + pre_analyze_call_arg_table_fields(analyzer, expr); + let expr_type = match analyzer.infer_expr(expr) { Ok(expr_type) => expr_type.get_result_slot_type(0).unwrap_or(expr_type), Err(InferFailReason::None) => LuaType::Unknown, @@ -503,12 +498,16 @@ pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> return Some(()); } + let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); + if is_table_field_already_analyzed(analyzer, &field, &member_id) { + return Some(()); + } + if let Some(field_key) = field.get_field_key() { if let LuaIndexKey::Expr(_) = &field_key { // Decl analysis leaves `[expr] = value` fields unresolved. If the key // already resolves here, materialize the member now. let db = &mut *analyzer.db; - let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); if db.get_member_index().get_member(&member_id).is_none() { let cache = analyzer .context @@ -536,8 +535,16 @@ pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> } } + if analyzer + .db + .get_type_index() + .get_type_cache(&member_id.into()) + .is_some() + { + return Some(()); + } let value_expr = field.get_value_expr()?; - let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); + let value_type = match analyzer.infer_expr(&value_expr.clone()) { Ok(value_type) => match value_type { LuaType::Def(ref_id) => LuaType::Ref(ref_id), @@ -564,6 +571,30 @@ pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> Some(()) } +fn is_table_field_already_analyzed( + analyzer: &LuaAnalyzer, + field: &LuaTableField, + member_id: &LuaMemberId, +) -> bool { + if analyzer + .db + .get_type_index() + .get_type_cache(&(*member_id).into()) + .is_none() + { + return false; + } + + match field.get_field_key() { + Some(LuaIndexKey::Expr(_)) => analyzer + .db + .get_member_index() + .get_member(member_id) + .is_some(), + _ => true, + } +} + fn special_assign_pattern( analyzer: &mut LuaAnalyzer, type_owner: LuaTypeOwner, @@ -627,3 +658,147 @@ fn get_delayed_definition_decl_id( } Some(decl_id) } + +fn pre_analyze_call_arg_table_fields(analyzer: &mut LuaAnalyzer, expr: &LuaExpr) { + let mut analyzed_fields = HashSet::new(); + pre_analyze_nested_table_fields(analyzer, expr.clone(), 0, &mut analyzed_fields); +} + +fn pre_analyze_nested_table_fields( + analyzer: &mut LuaAnalyzer, + expr: LuaExpr, + depth: usize, + analyzed_fields: &mut HashSet, +) { + if depth >= 250 { + return; + } + + let next_depth = depth + 1; + match expr { + LuaExpr::CallExpr(call_expr) => { + if let Some(prefix_expr) = call_expr.get_prefix_expr() { + pre_analyze_nested_table_fields(analyzer, prefix_expr, next_depth, analyzed_fields); + } + + if let Some(args_list) = call_expr.get_args_list() { + for arg in args_list.get_args() { + pre_analyze_nested_table_fields(analyzer, arg, next_depth, analyzed_fields); + } + } + } + LuaExpr::TableExpr(table_expr) => { + for field in table_expr.get_fields() { + if let Some(LuaIndexKey::Expr(key_expr)) = field.get_field_key() { + pre_analyze_nested_table_fields( + analyzer, + key_expr, + next_depth, + analyzed_fields, + ); + } + + if let Some(value_expr) = field.get_value_expr() { + pre_analyze_nested_table_fields( + analyzer, + value_expr, + next_depth, + analyzed_fields, + ); + } + + if analyzed_fields.insert(field.get_syntax_id()) { + analyze_table_field(analyzer, field.clone()); + } + } + } + LuaExpr::BinaryExpr(binary_expr) => { + if let Some((left, right)) = binary_expr.get_exprs() { + pre_analyze_nested_table_fields(analyzer, left, next_depth, analyzed_fields); + pre_analyze_nested_table_fields(analyzer, right, next_depth, analyzed_fields); + } + } + LuaExpr::UnaryExpr(unary_expr) => { + if let Some(inner_expr) = unary_expr.get_expr() { + pre_analyze_nested_table_fields(analyzer, inner_expr, next_depth, analyzed_fields); + } + } + LuaExpr::ParenExpr(paren_expr) => { + if let Some(inner_expr) = paren_expr.get_expr() { + pre_analyze_nested_table_fields(analyzer, inner_expr, next_depth, analyzed_fields); + } + } + LuaExpr::IndexExpr(index_expr) => { + if let Some(prefix_expr) = index_expr.get_prefix_expr() { + pre_analyze_nested_table_fields(analyzer, prefix_expr, next_depth, analyzed_fields); + } + + if let Some(LuaIndexKey::Expr(key_expr)) = index_expr.get_index_key() { + pre_analyze_nested_table_fields(analyzer, key_expr, next_depth, analyzed_fields); + } + } + LuaExpr::LiteralExpr(_) | LuaExpr::ClosureExpr(_) | LuaExpr::NameExpr(_) => {} + } +} + +fn expr_has_effect_table_call_arg(expr: LuaExpr) -> Option<()> { + match expr { + LuaExpr::CallExpr(call_expr) => { + if let Some(prefix_expr) = call_expr.get_prefix_expr() + && expr_has_effect_table_call_arg(prefix_expr).is_some() + { + return Some(()); + } + + let args_list = call_expr.get_args_list()?; + for arg in args_list.get_args() { + if let LuaExpr::TableExpr(table_expr) = &arg + && !table_expr.is_empty() + { + return Some(()); + } + + if expr_has_effect_table_call_arg(arg).is_some() { + return Some(()); + } + } + None + } + LuaExpr::TableExpr(table_expr) => { + for field in table_expr.get_fields() { + if let Some(LuaIndexKey::Expr(key_expr)) = field.get_field_key() + && expr_has_effect_table_call_arg(key_expr).is_some() + { + return Some(()); + } + + if let Some(value_expr) = field.get_value_expr() + && expr_has_effect_table_call_arg(value_expr).is_some() + { + return Some(()); + } + } + None + } + LuaExpr::BinaryExpr(binary_expr) => { + let (left, right) = binary_expr.get_exprs()?; + expr_has_effect_table_call_arg(left).or_else(|| expr_has_effect_table_call_arg(right)) + } + LuaExpr::UnaryExpr(unary_expr) => expr_has_effect_table_call_arg(unary_expr.get_expr()?), + LuaExpr::ParenExpr(paren_expr) => expr_has_effect_table_call_arg(paren_expr.get_expr()?), + LuaExpr::IndexExpr(index_expr) => { + if let Some(prefix_expr) = index_expr.get_prefix_expr() + && expr_has_effect_table_call_arg(prefix_expr).is_some() + { + return Some(()); + } + + if let Some(LuaIndexKey::Expr(key_expr)) = index_expr.get_index_key() { + return expr_has_effect_table_call_arg(key_expr); + } + + None + } + LuaExpr::LiteralExpr(_) | LuaExpr::ClosureExpr(_) | LuaExpr::NameExpr(_) => None, + } +} diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs index abe30c473..ea0a2f0ca 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs @@ -4,7 +4,7 @@ use smol_str::SmolStr; use crate::{ InFiled, InferFailReason, InferGuardRef, LuaInferCache, LuaInstanceType, LuaMemberId, - LuaMemberOwner, LuaOperatorOwner, TypeOps, TypeSubstitutor, check_type_compact, + LuaMemberOwner, LuaOperatorOwner, TypeMapper, TypeOps, check_type_compact, db_index::{ DbIndex, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaObjectType, LuaOperatorMetaMethod, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, @@ -355,7 +355,7 @@ fn index_generic_members_from_super_generics( db: &DbIndex, cache: &mut LuaInferCache, type_decl_id: &LuaTypeDeclId, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, index_expr: LuaIndexMemberExpr, infer_guard: &InferGuardRef, deep_guard: &mut DeepLevel, @@ -370,7 +370,7 @@ fn index_generic_members_from_super_generics( let type_decl_id = type_decl.get_id(); if let Some(super_types) = type_index.get_super_types(&type_decl_id) { super_types.iter().find_map(|super_type| { - let super_type = instantiate_type_generic(db, super_type, substitutor); + let super_type = instantiate_type_generic(db, super_type, mapper); find_function_type_by_member_key( db, cache, @@ -397,13 +397,13 @@ fn find_generic_member( let base_type = generic_type.get_base_type(); let generic_params = generic_type.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); + let mapper = TypeMapper::from_type_array(generic_params.clone()); if let LuaType::Ref(base_type_decl_id) = &base_type { let result = index_generic_members_from_super_generics( db, cache, base_type_decl_id, - &substitutor, + &mapper, index_expr.clone(), infer_guard, deep_guard, @@ -422,7 +422,7 @@ fn find_generic_member( deep_guard, )?; - Ok(instantiate_type_generic(db, &member_type, &substitutor)) + Ok(instantiate_type_generic(db, &member_type, &mapper)) } fn find_instance_member_decl_type( @@ -772,17 +772,17 @@ fn find_member_by_index_generic( return Err(InferFailReason::None); }; let generic_params = generic.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); + let mapper = TypeMapper::from_type_array(generic_params.clone()); let type_index = db.get_type_index(); let type_decl = type_index .get_type_decl(&type_decl_id) .ok_or(InferFailReason::None)?; if type_decl.is_alias() { - if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { + if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&mapper)) { return find_function_type_by_operator( db, cache, - &instantiate_type_generic(db, &origin_type, &substitutor), + &instantiate_type_generic(db, &origin_type, &mapper), index_expr.clone(), &infer_guard.fork(), deep_guard, @@ -801,9 +801,9 @@ fn find_member_by_index_generic( .get_operator(index_operator_id) .ok_or(InferFailReason::None)?; let operand = index_operator.get_operand(db); - let instianted_operand = instantiate_type_generic(db, &operand, &substitutor); + let instianted_operand = instantiate_type_generic(db, &operand, &mapper); let return_type = - instantiate_type_generic(db, &index_operator.get_result(db)?, &substitutor); + instantiate_type_generic(db, &index_operator.get_result(db)?, &mapper); let result = find_index_metamethod(db, cache, &member_key, &instianted_operand, &return_type); @@ -826,7 +826,7 @@ fn find_member_by_index_generic( let result = find_function_type_by_operator( db, cache, - &instantiate_type_generic(db, &super_type, &substitutor), + &instantiate_type_generic(db, &super_type, &mapper), index_expr.clone(), &infer_guard.fork(), deep_guard, diff --git a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs index 481e06059..f75e18d6f 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs @@ -105,6 +105,38 @@ mod test { assert_eq!(b, LuaType::Integer); } + #[test] + fn test_issue_490_split_iterator_state() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T: table, K, V + ---@param t T + ---@return fun(table: table, index?: K):K, V + ---@return T + local function spairs(t) + return next, t + end + + --- @type table + local t = { a = 1, b = 2, c = 3 } + + local iter, state = spairs(t) + + for name, value in iter, state do + a = name + b = value + end + "#, + ); + + let a = ws.expr_ty("a"); + let b = ws.expr_ty("b"); + assert_eq!(a, LuaType::String); + assert_eq!(b, LuaType::Integer); + } + #[test] fn test_enum_key_pairs() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -211,6 +243,62 @@ mod test { assert_eq!(value_union.into_set(), expected_values); } + #[test] + fn test_pairs_metamethod_extracts_iterator_from_multi_return() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class PairSource + local PairSource + + ---@return fun(): string, number + ---@return table + ---@return nil + function PairSource:__pairs() + end + + ---@type PairSource + local source + + for k, v in pairs(source) do + key_out = k + value_out = v + end + "#, + ); + + assert_eq!(ws.expr_ty("key_out"), LuaType::String); + assert_eq!(ws.expr_ty("value_out"), LuaType::Number); + } + + #[test] + fn test_pairs_metamethod_extracts_iterator_from_single_return() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class PairSource + local PairSource + + ---@return fun(): string + ---@return table + ---@return nil + function PairSource:__pairs() + end + + ---@type PairSource + local source + + for k, v in pairs(source) do + key_out = k + end + "#, + ); + + assert_eq!(ws.expr_ty("key_out"), LuaType::String); + } + #[test] fn test_issue_291() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs new file mode 100644 index 000000000..798c5e667 --- /dev/null +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs @@ -0,0 +1,237 @@ +#[cfg(test)] +mod test { + use crate::{DiagnosticCode, VirtualWorkspace}; + + #[test] + fn test_builtin_pick_preserves_selected_properties() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinPickUser + ---@field name string + ---@field age number + ---@field email string + ---@field nickname? string + + ---@type Pick + local picked + PickedName = picked.name + PickedAge = picked.age + PickedNickname = picked.nickname + + ---@type Pick + local pickedAll + PickedAllEmail = pickedAll.email + + ---@type Pick<{id: integer, enabled: boolean, label?: string}, "id" | "label"> + local pickedLiteral + PickedLiteralId = pickedLiteral.id + PickedLiteralLabel = pickedLiteral.label + "#, + ); + + assert_eq!(ws.expr_ty("PickedName"), ws.ty("string")); + assert_eq!(ws.expr_ty("PickedAge"), ws.ty("number")); + assert_eq!(ws.expr_ty("PickedNickname"), ws.ty("string?")); + assert_eq!(ws.expr_ty("PickedAllEmail"), ws.ty("string")); + assert_eq!(ws.expr_ty("PickedLiteralId"), ws.ty("integer")); + assert_eq!(ws.expr_ty("PickedLiteralLabel"), ws.ty("string?")); + } + + #[test] + fn test_builtin_pick_matches_ts6_key_constraint() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinPickConstraintUser + ---@field name string + ---@field age number + "#, + ); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Pick + local picked + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Pick + local picked + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Pick + local picked + local name = picked.name + "# + )); + } + + #[test] + fn test_builtin_pick_empty_keyof_domain_converges() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinEmptyPickClass + + ---@type Pick<{}, keyof {}> + local pickedEmptyObject + PickedEmptyObjectMissing = pickedEmptyObject.missing + + ---@type Pick + local pickedEmptyClass + PickedEmptyClassMissing = pickedEmptyClass.missing + "#, + ); + + assert_eq!(ws.expr_ty("PickedEmptyObjectMissing"), ws.ty("nil")); + assert_eq!(ws.expr_ty("PickedEmptyClassMissing"), ws.ty("nil")); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Pick<{}, keyof {}> + local picked + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Pick<{}, keyof {}> + local picked + local missing = picked.missing + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@class BuiltinEmptyPickDiagnosticClass + + ---@type Pick + local picked + local missing = picked.missing + "# + )); + } + + #[test] + fn test_builtin_omit_removes_properties() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinOmitUser + ---@field name string + ---@field age number + ---@field email string + ---@field nickname? string + + ---@type Omit + local omitted + OmittedName = omitted.name + OmittedAge = omitted.age + OmittedNickname = omitted.nickname + OmittedEmail = omitted.email + + ---@type Pick> + local pickedWithoutEmail + PickedWithoutEmailEmail = pickedWithoutEmail.email + + ---@type Omit<{id: integer, enabled: boolean, label?: string}, "enabled"> + local omittedLiteral + OmittedLiteralId = omittedLiteral.id + OmittedLiteralLabel = omittedLiteral.label + "#, + ); + + assert_eq!(ws.expr_ty("OmittedName"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmittedAge"), ws.ty("number")); + assert_eq!(ws.expr_ty("OmittedNickname"), ws.ty("string?")); + assert_eq!(ws.expr_ty("PickedWithoutEmailEmail"), ws.ty("nil")); + assert_eq!(ws.expr_ty("OmittedEmail"), ws.ty("nil")); + assert_eq!(ws.expr_ty("OmittedLiteralId"), ws.ty("integer")); + assert_eq!(ws.expr_ty("OmittedLiteralLabel"), ws.ty("string?")); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Omit + local omitted + local email = omitted.email + "# + )); + } + + #[test] + fn test_builtin_omit_matches_ts6_keyof_any_behavior() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinOmitKeyUser + ---@field name string + ---@field age number + ---@field email string + + ---@type Omit + local omitMissing + OmitMissingName = omitMissing.name + OmitMissingEmail = omitMissing.email + + ---@type Omit + local omitNever + OmitNeverName = omitNever.name + OmitNeverEmail = omitNever.email + + ---@type Omit + local omitAll + OmitAllName = omitAll.name + "#, + ); + + assert_eq!(ws.expr_ty("OmitMissingName"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitMissingEmail"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitNeverName"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitNeverEmail"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitAllName"), ws.ty("nil")); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Omit + local omitted + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Omit + local omitted + local name = omitted.name + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Omit + local omitted + local name = omitted.name + "# + )); + } +} diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs index 5eef7210c..226513f3e 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs @@ -47,6 +47,57 @@ mod test { assert_eq!(ws.humanize_type(a_ty), "string"); } + #[test] + fn test_object_literal_infer_nested_call_argument() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias ExtractX T extends { x: infer X } and X or never + + ---@generic T + ---@param value T + ---@return ExtractX + function extractX(value) end + + ---@generic T + ---@param value T + ---@return T + function identity(value) end + + A = identity(extractX({ x = 1 })) + "#, + ); + + let a_ty = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(a_ty), "integer"); + } + + #[test] + fn test_object_literal_infer_nested_call_inside_table_field() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias ExtractX T extends { x: infer X } and X or never + ---@alias ExtractInner T extends { inner: infer I } and I or never + + ---@generic T + ---@param value T + ---@return ExtractX + function extractX(value) end + + ---@generic T + ---@param value T + ---@return ExtractInner + function extractInner(value) end + + B = extractInner({ inner = extractX({ x = 1 }) }) + "#, + ); + + let b_ty = ws.expr_ty("B"); + assert_eq!(ws.humanize_type(b_ty), "integer"); + } + #[test] fn test_object_literal_infer_from_class() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 3a1b462b9..e562bc776 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -3,8 +3,8 @@ mod test { use emmylua_parser::LuaClosureExpr; use crate::{ - DiagnosticCode, LuaSignatureId, LuaType, LuaTypeDeclId, VirtualWorkspace, - complete_type_generic_args, + DiagnosticCode, GenericTplId, LuaSignatureId, LuaType, LuaTypeDeclId, TypeMapper, + TypeMapperValue, VirtualWorkspace, complete_type_generic_args, instantiate_type_generic, }; #[test] @@ -139,7 +139,7 @@ mod test { "#, ); let a_ty = ws.expr_ty("a"); - assert_eq!(a_ty, ws.ty("unknown")); + assert_eq!(a_ty, LuaType::Unknown); } // Currently fails: @@ -412,6 +412,172 @@ mod test { )); } + #[test] + fn test_keyof_alias_residual_resolves_after_forwarding() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Keys keyof T + ---@alias ForwardKeys Keys + + ---@param key "a" | "b" + function accept(key) end + "#, + ); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type ForwardKeys<{ a: string, b: number }> + local key + accept(key) + "# + )); + } + + #[test] + fn test_mapped_alias_residual_resolves_after_forwarding() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Copy { [K in keyof T]: T[K]; } + ---@alias ForwardCopy Copy + + ---@type ForwardCopy<{ a: string, b: number }> + local copy + + A = copy.a + B = copy.b + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + assert_eq!(ws.expr_ty("B"), ws.ty("number")); + } + + #[test] + fn test_mapped_unresolved_key_domain_preserves_residual() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Copy { [K in keyof T]: T[K]; } + + ---@generic T + ---@param value Copy + ---@return Copy + function keep(value) end + + ---@type Copy<{ a: string }> + local concrete + + A = keep(concrete).a + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + } + + #[test] + fn test_alias_argument_binding_ignores_shadowing_function_generic() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Box fun(x: T): T + + ---@type Box + local f + + Result = f(1) + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "1"); + assert!(ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type Box + local f + + f(1) + "# + )); + } + + #[test] + fn test_alias_argument_binding_ignores_shadowing_mapped_key() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Shadow { [T in keyof { a: string }]: T; } + + ---@type Shadow + local value + + A = value.a + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty(r#""a""#)); + } + + #[test] + fn test_conditional_alias_residual_resolves_after_forwarding() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Extract T extends U and T or never + ---@alias KeepA Extract + ---@alias Forward KeepA + "#, + ); + + let generic_ty = ws.ty(r#"Forward<"a" | "b">"#); + let mapper = TypeMapper::empty(); + let instantiated = instantiate_type_generic(ws.get_db_mut(), &generic_ty, &mapper); + assert_eq!(instantiated, ws.ty(r#""a""#)); + } + + #[test] + fn test_nested_mapped_conditional_alias_residual_resolves() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Wrapper + ---@alias UnwrapFields { [K in keyof T]: T[K] extends Wrapper and U or T[K]; } + ---@alias Forward UnwrapFields + + ---@type Forward<{ a: Wrapper, b: number }> + local value + + A = value.a + B = value.b + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + assert_eq!(ws.expr_ty("B"), ws.ty("number")); + } + + #[test] + fn test_recursive_alias_instantiation_budget_falls_back_safely() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Loop Loop + ---@alias Forward Loop + + ---@type Forward + local value + + Value = value + "#, + ); + + let value_ty = ws.expr_ty("Value"); + assert_eq!(ws.humanize_type(value_ty), "Forward"); + } + #[test] fn test_issue_787() { let mut ws = VirtualWorkspace::new(); @@ -494,9 +660,12 @@ mod test { A, B, C = f1(1, "2", true) "#, )); - assert_eq!(ws.expr_ty("A"), ws.ty("integer")); - assert_eq!(ws.expr_ty("B"), ws.ty("string")); - assert_eq!(ws.expr_ty("C"), ws.ty("boolean")); + let a_ty = ws.expr_ty("A"); + let b_ty = ws.expr_ty("B"); + let c_ty = ws.expr_ty("C"); + assert_eq!(ws.humanize_type(a_ty), "1"); + assert_eq!(ws.humanize_type(b_ty), "\"2\""); + assert_eq!(ws.humanize_type(c_ty), "true"); } { ws.def( @@ -533,7 +702,8 @@ mod test { G, H = f3(1, "2") "#, )); - assert_eq!(ws.expr_ty("G"), ws.ty("integer")); + let g_ty = ws.expr_ty("G"); + assert_eq!(ws.humanize_type(g_ty), "1"); assert_eq!(ws.expr_ty("H"), ws.ty("any")); } @@ -681,7 +851,7 @@ mod test { "#, )); let result_ty = ws.expr_ty("result"); - assert_eq!(ws.humanize_type(result_ty), "string"); + assert_eq!(ws.humanize_type(result_ty), "\"\""); } #[test] @@ -699,7 +869,7 @@ mod test { ); let result_ty = ws.expr_ty("result"); - assert_eq!(ws.humanize_type(result_ty), "integer"); + assert_eq!(ws.humanize_type(result_ty), "1"); } #[test] @@ -764,6 +934,7 @@ mod test { .expect("Box generic params"); assert_eq!(box_params.len(), 1); assert_eq!(box_params[0].name.as_str(), "T"); + assert_eq!(box_params[0].tpl_id, Some(GenericTplId::Type(0))); let box_default = box_params[0] .default_type .clone() @@ -779,6 +950,7 @@ mod test { .expect("Optional generic params"); assert_eq!(optional_params.len(), 1); assert_eq!(optional_params[0].name.as_str(), "T"); + assert_eq!(optional_params[0].tpl_id, Some(GenericTplId::Type(0))); let optional_default = optional_params[0] .default_type .clone() @@ -816,6 +988,35 @@ mod test { assert_eq!(ws.humanize_type(default_type), "string"); } + #[test] + fn test_signature_instantiation_keeps_template_shapes() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@generic T + ---@param x T + ---@return T + local function id(x) + return x + end + "#, + ); + + let closure = ws.get_node::(file_id); + let signature_id = LuaSignatureId::from_closure(file_id, &closure); + let mapper = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let instantiated = instantiate_type_generic( + ws.analysis.compilation.get_db(), + &LuaType::Signature(signature_id), + &mapper, + ); + + assert_eq!(ws.humanize_type(instantiated), "fun(x: string) -> string"); + } + #[test] fn test_bare_generic_type_uses_default() { let mut ws = VirtualWorkspace::new(); @@ -854,6 +1055,83 @@ mod test { assert_eq!(ws.humanize_type(value_ty), "Base"); } + #[test] + fn test_doc_type_binding_instantiates_alias_origin() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Box { value: T } + ---@alias Forward Box + + ---@type Forward + local value + + result = value + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("{ value: string }")); + } + + #[test] + fn test_doc_type_binding_preserves_residual_aliases() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Extract T extends U and T or never + ---@alias KeepA Extract + ---@alias ForwardConditional KeepA + + ---@alias Copy { [K in keyof T]: T[K] } + ---@alias ForwardMapped Copy + + ---@generic T + local function f() + ---@type ForwardConditional + local cond + + ---@type ForwardMapped + local mapped + + cond_result = cond + mapped_result = mapped + end + "#, + ); + + let cond_ty = ws.expr_ty("cond_result"); + let cond_desc = ws.humanize_type(cond_ty); + assert_eq!(cond_desc, r#"Extract"#); + + let mapped_ty = ws.expr_ty("mapped_result"); + assert_eq!(ws.humanize_type(mapped_ty), "Copy"); + } + + #[test] + fn test_detailed_humanize_shows_mapped_alias_body() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Copy { [K in keyof T]: T[K] } + ---@alias ForwardMapped Copy + + ---@generic T + local function f() + ---@type ForwardMapped + local mapped + + mapped_result = mapped + end + "#, + ); + + let mapped_ty = ws.expr_ty("mapped_result"); + assert_eq!(ws.humanize_type(mapped_ty.clone()), "Copy"); + let mapped_desc = ws.humanize_type_detailed(mapped_ty); + assert!(mapped_desc.starts_with("Copy = { [K in keyof T]: ")); + } + #[test] fn test_partial_generic_type_fills_trailing_default() { let mut ws = VirtualWorkspace::new(); @@ -1075,7 +1353,7 @@ mod test { let explicit_result = ws.expr_ty("ExplicitResult"); assert_eq!(ws.humanize_type(explicit_result), "number"); let inferred_result = ws.expr_ty("InferredResult"); - assert_eq!(ws.humanize_type(inferred_result), "integer"); + assert_eq!(ws.humanize_type(inferred_result), "1"); } #[test] @@ -1217,7 +1495,7 @@ mod test { } #[test] - fn test_constant_decay() { + fn test_plain_tpl_literal_key_inference_widens_through_finalize() { let mut ws = VirtualWorkspace::new(); ws.def( r#" @@ -1250,6 +1528,220 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "integer"); } + #[test] + fn test_const_tpl_candidate_preserves_literal_through_plain_return() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias std.ConstTpl unknown + + ---@generic T + ---@param value std.ConstTpl + ---@return T + function keep_const(value) + end + + result = keep_const("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_const_tpl_candidate_preserves_structural_literals() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias std.ConstTpl unknown + + ---@generic T + ---@param value std.ConstTpl + ---@return T + function keep_const(value) + end + + table_result = keep_const({ kind = "mode", count = 1 }) + table_kind = table_result.kind + table_count = table_result.count + + ---@type { kind: "mode", count: 1 } + local object + + object_result = keep_const(object) + object_kind = object_result.kind + object_count = object_result.count + + ---@type ["mode", 1] + local tuple + + tuple_result = keep_const(tuple) + tuple_first = tuple_result[1] + tuple_second = tuple_result[2] + "#, + ); + + let table_kind = ws.expr_ty("table_kind"); + let table_count = ws.expr_ty("table_count"); + let object_kind = ws.expr_ty("object_kind"); + let object_count = ws.expr_ty("object_count"); + let tuple_first = ws.expr_ty("tuple_first"); + let tuple_second = ws.expr_ty("tuple_second"); + + assert_eq!(ws.humanize_type(table_kind), "\"mode\""); + assert_eq!(ws.humanize_type(table_count), "1"); + assert_eq!(ws.humanize_type(object_kind), "\"mode\""); + assert_eq!(ws.humanize_type(object_count), "1"); + assert_eq!(ws.humanize_type(tuple_first), "\"mode\""); + assert_eq!(ws.humanize_type(tuple_second), "1"); + } + + #[test] + fn test_plain_tpl_top_level_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function id(value) + end + + result = id("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_transparent_alias_top_level_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Id T + + ---@generic T + ---@param value T + ---@return Id + function id(value) + end + + result = id("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_transparent_alias_root_union_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Id T + + ---@generic T + ---@param value T + ---@return Id|nil + function maybe(value) + end + + result = maybe("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), r#""mode"?"#); + } + + #[test] + fn test_plain_tpl_root_union_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T|nil + function maybe(value) + end + + result = maybe("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), r#""mode"?"#); + } + + #[test] + fn test_plain_tpl_top_level_return_preserves_primitive_literal_union() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function id(value) + end + + ---@alias Choice "left" | "right" + + ---@type Choice + local choice + + result = id(choice) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("\"left\" | \"right\"")); + } + + #[test] + fn test_primitive_constraint_preserves_literal_candidate() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T: string + ---@param value T + ---@return T + function constrained(value) + end + + result = constrained("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_finalized_table_const_self_reference_widens_without_recursing() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function id(value) + end + + local t = { kind = "mode" } + t.self = t + + result = id(t) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("{ kind: string, self: table }")); + } + #[test] fn test_extends_true() { let mut ws = VirtualWorkspace::new(); @@ -1475,6 +1967,26 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "number"); } + #[test] + fn test_distributed_function_generic_conditional_return_filters_union_members() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T extends string and T or never + function extractString(value) end + + ---@type string|integer + local value + + A = extractString(value) + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + } + #[test] fn test_union_never() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs index 00307054f..14a0f93c7 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs @@ -395,6 +395,34 @@ mod test { ); } + #[test] + fn test_call_arg_nested_table_expr_key_doc_const() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + local function id(value) + return value + end + + ---@type 'field' + local key = "field" + local t = id({ nested = { [key] = 1 } }).nested + value = t[key] + "#, + ); + + let value_ty = ws.expr_ty("value"); + assert!( + matches!(value_ty, LuaType::Integer | LuaType::IntegerConst(_)), + "expected integer type, got {:?}", + value_ty + ); + } + #[test] fn test_union_member_access_preserves_never() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/mod.rs b/crates/emmylua_code_analysis/src/compilation/test/mod.rs index a6c8629fd..3b6807bde 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/mod.rs @@ -10,6 +10,7 @@ mod decl_test; mod diagnostic_disable_test; mod flow; mod for_range_var_infer_test; +mod generic_builtin; mod generic_infer_test; mod generic_test; mod infer_str_tpl_test; diff --git a/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs b/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs index 47ef6dfb4..fc4fb062a 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs @@ -35,7 +35,7 @@ mod test { let mut ws = VirtualWorkspace::new_with_init_std_lib(); ws.def( r#" - ---@overload fun(t: T): std.Unpack + ---@overload fun(t: std.ConstTpl): std.Unpack ---@overload fun(t: number): number local function f(t) end @@ -55,7 +55,7 @@ mod test { ws.def( r#" ---@class Obj - ---@field unpack (fun(self: Obj, t: T): std.Unpack) | (fun(self: Obj, t: number): number) + ---@field unpack (fun(self: Obj, t: std.ConstTpl): std.Unpack) | (fun(self: Obj, t: number): number) local Obj = {} a, b = Obj:unpack({ 1, 2 }) @@ -75,7 +75,7 @@ mod test { local Obj = {} ---@generic T - ---@param t T + ---@param t std.ConstTpl ---@return std.Unpack function Obj:unpack(t) end @@ -130,6 +130,44 @@ mod test { assert_eq!(ws.humanize_type(b_ty), "number"); } + #[test] + fn test_unpack_alias_call_uses_uninferred_generic_default() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + ws.def( + r#" + ---@generic T = [string, number] + ---@return std.Unpack + function f() + end + + a, b = f() + "#, + ); + + let a_ty = ws.expr_ty("a"); + let b_ty = ws.expr_ty("b"); + assert_eq!(ws.humanize_type(a_ty), "string"); + assert_eq!(ws.humanize_type(b_ty), "number"); + } + + #[test] + fn test_unpack_alias_call_uses_uninferred_generic_constraint() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + ws.def( + r#" + ---@generic T: string[] + ---@return std.Unpack + function f() + end + + a = f() + "#, + ); + + let a_ty = ws.expr_ty("a"); + assert_eq!(ws.humanize_type(a_ty), "string?"); + } + #[test] fn test_issue_484() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs index 1a66b2031..08978d8be 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs @@ -1,10 +1,11 @@ use smol_str::SmolStr; -use crate::{LuaAttributeUse, LuaType}; +use crate::{GenericTplId, LuaAttributeUse, LuaType}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct GenericParam { pub name: SmolStr, + pub tpl_id: Option, pub type_constraint: Option, pub default_type: Option, pub attributes: Option>, @@ -19,9 +20,15 @@ impl GenericParam { ) -> Self { Self { name, + tpl_id: None, type_constraint, default_type, attributes, } } + + pub fn with_tpl_id(mut self, tpl_id: Option) -> Self { + self.tpl_id = tpl_id; + self + } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs index f26732263..4bcd26cef 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs @@ -5,9 +5,9 @@ use itertools::Itertools; use crate::{ AsyncState, DbIndex, LuaAliasCallType, LuaConditionalType, LuaFunctionType, LuaGenericType, - LuaIntersectionType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaSignatureId, - LuaStringTplType, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, TypeSubstitutor, - VariadicType, + LuaIntersectionType, LuaMappedType, LuaMemberKey, LuaMemberOwner, LuaObjectType, + LuaSignatureId, LuaStringTplType, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, + TypeMapper, VariadicType, }; use super::{LuaAliasCallKind, LuaMultiLineUnion}; @@ -207,6 +207,7 @@ impl<'a> TypeHumanizer<'a> { LuaType::MultiLineUnion(multi_union) => { self.write_multi_line_union_type(multi_union, w) } + LuaType::Mapped(mapped) => self.write_mapped_type(mapped, w), LuaType::TypeGuard(inner) => { w.write_str("TypeGuard<")?; let saved = self.level; @@ -717,8 +718,10 @@ impl<'a> TypeHumanizer<'a> { return Ok(()); } - let substitutor = TypeSubstitutor::from_type_array(generic.get_params().clone()); - if let Some(origin_type) = type_decl.get_alias_origin(self.db, Some(&substitutor)) { + let mapper = TypeMapper::from_type_array(generic.get_params().clone()); + if let Some(origin_type) = + type_decl.get_alias_origin_preserve_tpl(self.db, Some(&mapper)) + { w.write_str(" = ")?; let saved = self.level; self.level = self.child_level(); @@ -1025,6 +1028,68 @@ impl<'a> TypeHumanizer<'a> { } } + // ─── Mapped ───────────────────────────────────────────────────── + + fn write_mapped_type(&mut self, mapped: &LuaMappedType, w: &mut W) -> fmt::Result { + if self.level == RenderLevel::Minimal { + return w.write_str("{...}"); + } + + w.write_str("{ ")?; + if mapped.is_readonly { + w.write_str("readonly ")?; + } + + let saved = self.level; + self.level = self.child_level(); + + w.write_char('[')?; + w.write_str(mapped.param.1.name.as_str())?; + w.write_str(" in ")?; + self.write_mapped_constraint(mapped.param.1.type_constraint.as_ref(), w)?; + w.write_char(']')?; + if mapped.is_optional { + w.write_char('?')?; + } + w.write_str(": ")?; + self.write_mapped_value(&mapped.value, w)?; + + self.level = saved; + w.write_str(" }") + } + + fn write_mapped_constraint( + &mut self, + constraint: Option<&LuaType>, + w: &mut W, + ) -> fmt::Result { + match constraint { + Some(LuaType::Call(alias_call)) + if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 => + { + w.write_str("keyof ")?; + self.write_type(&alias_call.get_operands()[0], w) + } + Some(ty) => self.write_type(ty, w), + None => w.write_str("unknown"), + } + } + + fn write_mapped_value(&mut self, value: &LuaType, w: &mut W) -> fmt::Result { + if let LuaType::Call(alias_call) = value + && alias_call.get_call_kind() == LuaAliasCallKind::Index + && alias_call.get_operands().len() == 2 + { + self.write_type(&alias_call.get_operands()[0], w)?; + w.write_char('[')?; + self.write_type(&alias_call.get_operands()[1], w)?; + return w.write_char(']'); + } + + self.write_type(value, w) + } + // ─── helper: write a table member (key: type) ─────────────────── fn write_table_member_field( diff --git a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs index ba8ef4a05..8b655303d 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs @@ -7,8 +7,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use smol_str::SmolStr; use crate::{ - DbIndex, FileId, LuaMemberKey, LuaMemberOwner, TypeSubstitutor, db_index::WorkspaceId, - instantiate_type_generic, + DbIndex, FileId, LuaMemberKey, LuaMemberOwner, TplResolvePolicy, TypeMapper, + db_index::WorkspaceId, instantiate_type_generic, instantiate_type_generic_full, }; use super::{LuaType, LuaUnionType}; @@ -130,17 +130,42 @@ impl LuaTypeDecl { .map(|idx| &self.id.get_name()[..idx]) } - pub fn get_alias_origin( + pub fn get_alias_origin(&self, db: &DbIndex, mapper: Option<&TypeMapper>) -> Option { + match &self.extra { + LuaTypeExtra::Alias { + origin: Some(origin), + } => { + let mapper = match mapper { + Some(mapper) => mapper, + None => return Some(origin.clone()), + }; + + let type_decl_id = self.get_id(); + if db + .get_type_index() + .get_generic_params(&type_decl_id) + .is_none() + { + return Some(origin.clone()); + } + + Some(instantiate_type_generic(db, origin, mapper)) + } + _ => None, + } + } + + pub fn get_alias_origin_preserve_tpl( &self, db: &DbIndex, - substitutor: Option<&TypeSubstitutor>, + mapper: Option<&TypeMapper>, ) -> Option { match &self.extra { LuaTypeExtra::Alias { origin: Some(origin), } => { - let substitutor = match substitutor { - Some(substitutor) => substitutor, + let mapper = match mapper { + Some(mapper) => mapper, None => return Some(origin.clone()), }; @@ -153,7 +178,13 @@ impl LuaTypeDecl { return Some(origin.clone()); } - Some(instantiate_type_generic(db, origin, substitutor)) + Some(instantiate_type_generic_full( + db, + origin, + mapper, + None, + TplResolvePolicy::PreserveTplRef, + )) } _ => None, } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/test.rs b/crates/emmylua_code_analysis/src/db_index/type/types/test.rs index ec63747de..69c5297ad 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/test.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/test.rs @@ -20,6 +20,15 @@ mod tests { ); } + #[test] + fn test_single_return_uses_result_slot_extraction() { + assert_eq!( + LuaType::String.get_result_slot_type(0), + Some(LuaType::String) + ); + assert_eq!(LuaType::String.get_result_slot_type(1), None); + } + #[test] fn test_deep_contain_tpl_uses_iterative_walk() { let mut ty = LuaType::TplRef( diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs index cc5594eff..ae1ba0eb8 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs @@ -1,6 +1,6 @@ use emmylua_parser::{ LuaAst, LuaAstNode, LuaCallExpr, LuaClosureExpr, LuaComment, LuaDocGenericDeclList, - LuaDocTagAlias, LuaDocTagClass, LuaDocTagGeneric, LuaDocTagType, LuaDocType, + LuaDocGenericType, LuaDocTagAlias, LuaDocTagClass, LuaDocTagGeneric, LuaDocTagType, LuaDocType, }; use rowan::TextRange; use smol_str::SmolStr; @@ -13,8 +13,8 @@ use crate::semantic::{ use crate::{ DiagnosticCode, DocTypeInferContext, GenericTplId, LuaArrayType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaSignatureId, LuaStringTplType, LuaTupleType, LuaType, - LuaUnionType, RenderLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, - TypeSubstitutor, VariadicType, humanize_type, infer_doc_type, instantiate_type_generic, + LuaTypeNode, LuaUnionType, RenderLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, + TypeMapper, VariadicType, humanize_type, infer_doc_type, instantiate_type_generic, }; pub struct GenericConstraintMismatchChecker; @@ -55,7 +55,7 @@ fn check_call_expr( let Some(CallConstraintContext { params, args, - substitutor, + mapper, }) = build_call_constraint_context(semantic_model, &call_expr) else { return Some(()); @@ -76,7 +76,7 @@ fn check_call_expr( param_type, &args, false, - &substitutor, + &mapper, ); } @@ -617,55 +617,141 @@ fn check_doc_tag_type( let type_list = doc_tag_type.get_type_list(); let doc_ctx = DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); for doc_type in type_list { - let explicit_args = explicit_generic_args(&doc_type); - if explicit_args.is_empty() { - continue; - } + check_doc_type_generic_constraints(context, semantic_model, doc_ctx, &doc_type); + } + Some(()) +} - let type_ref = infer_doc_type(doc_ctx, &doc_type); - let generic_type = match type_ref { - LuaType::Generic(generic_type) => generic_type, - _ => continue, - }; +fn check_doc_type_generic_constraints( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + doc_ctx: DocTypeInferContext<'_>, + doc_type: &LuaDocType, +) -> Option<()> { + let LuaDocType::Generic(generic_doc_type) = doc_type else { + return Some(()); + }; + + let explicit_args = explicit_generic_args(generic_doc_type); + if explicit_args.is_empty() { + return Some(()); + } - let generic_params = semantic_model + let name = generic_doc_type.get_name_type()?.get_name_text()?; + let type_decl = semantic_model.get_db().get_type_index().find_type_decl( + semantic_model.get_file_id(), + &name, + semantic_model .get_db() - .get_type_index() - .get_generic_params(&generic_type.get_base_type_id())?; - for (i, param_type) in generic_type - .get_params() - .iter() - .take(explicit_args.len()) - .enumerate() - { - let extend_type = generic_params.get(i)?.type_constraint.clone()?; - let result = semantic_model.type_check_detail(&extend_type, param_type); - if result.is_err() { - add_type_check_diagnostic( - context, - semantic_model, - explicit_args.get(i)?.get_range(), - &extend_type, - param_type, - result, - ); + .resolve_workspace_id(semantic_model.get_file_id()), + )?; + let type_id = type_decl.get_id(); + let generic_params = semantic_model + .get_db() + .get_type_index() + .get_generic_params(&type_id)?; + + let instantiate_arg = explicit_arg_instantiation_flags(&generic_params, explicit_args.len()); + let empty_mapper = TypeMapper::empty(); + let param_types = explicit_args + .iter() + .enumerate() + .map(|(idx, doc_type)| { + let ty = infer_doc_type(doc_ctx, doc_type); + if instantiate_arg.get(idx).copied().unwrap_or(false) { + instantiate_type_generic(semantic_model.get_db(), &ty, &empty_mapper) + } else { + ty } + }) + .collect::>(); + + let mapper = TypeMapper::from_alias(semantic_model.get_db(), param_types.clone(), &type_id); + + for (i, param_type) in param_types.iter().enumerate() { + let Some(explicit_arg) = explicit_args.get(i) else { + continue; + }; + let Some(extend_type) = generic_params + .get(i) + .and_then(|param| param.type_constraint.clone()) + else { + continue; + }; + + let mut extend_type = + instantiate_type_generic(semantic_model.get_db(), &extend_type, &mapper); + extend_type = normalize_keyof_any_constraint(extend_type); + let result = semantic_model.type_check_detail(&extend_type, param_type); + if result.is_err() { + add_type_check_diagnostic( + context, + semantic_model, + explicit_arg.get_range(), + &extend_type, + param_type, + result, + ); } } + Some(()) } -fn explicit_generic_args(doc_type: &LuaDocType) -> Vec { - let LuaDocType::Generic(generic_doc_type) = doc_type else { - return Vec::new(); - }; - +fn explicit_generic_args(generic_doc_type: &LuaDocGenericType) -> Vec { generic_doc_type .get_generic_types() .map(|type_list| type_list.get_types().collect()) .unwrap_or_default() } +fn explicit_arg_instantiation_flags( + generic_params: &[crate::GenericParam], + explicit_arg_count: usize, +) -> Vec { + let mut flags = vec![false; explicit_arg_count]; + for (constraint_index, param) in generic_params.iter().enumerate().take(explicit_arg_count) { + let Some(constraint) = param.type_constraint.as_ref() else { + continue; + }; + + flags[constraint_index] = true; + for (arg_index, referenced_param) in + generic_params.iter().enumerate().take(explicit_arg_count) + { + let tpl_id = referenced_param + .tpl_id + .unwrap_or(GenericTplId::Type(arg_index as u32)); + if type_contains_tpl_ref(constraint, tpl_id) { + flags[arg_index] = true; + } + } + } + + flags +} + +fn type_contains_tpl_ref(ty: &LuaType, tpl_id: GenericTplId) -> bool { + ty.any_type(|ty| match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + LuaType::StrTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + _ => false, + }) +} + +fn normalize_keyof_any_constraint(ty: LuaType) -> LuaType { + match ty { + LuaType::Call(alias_call) + if alias_call.get_call_kind() == crate::LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 + && alias_call.get_operands()[0].is_any() => + { + LuaType::from_vec(vec![LuaType::String, LuaType::Integer, LuaType::Number]) + } + _ => ty, + } +} + #[allow(clippy::too_many_arguments)] fn check_param( context: &mut DiagnosticContext, @@ -675,7 +761,7 @@ fn check_param( param_type: &LuaType, args: &[CallConstraintArg], from_union: bool, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, ) -> Option<()> { // 应该先通过泛型体操约束到唯一类型再进行检查 match param_type { @@ -683,7 +769,7 @@ fn check_param( let extend_type = str_tpl_ref.get_constraint().cloned().map(|ty| { normalize_constraint_type( semantic_model.get_db(), - instantiate_type_generic(semantic_model.get_db(), &ty, substitutor), + instantiate_type_generic(semantic_model.get_db(), &ty, mapper), ) }); let arg = args.get(param_index)?; @@ -706,7 +792,7 @@ fn check_param( let extend_type = tpl_ref.get_constraint().cloned().map(|ty| { normalize_constraint_type( semantic_model.get_db(), - instantiate_type_generic(semantic_model.get_db(), &ty, substitutor), + instantiate_type_generic(semantic_model.get_db(), &ty, mapper), ) }); let arg_type = args.get(param_index).map(|arg| &arg.check_type); @@ -725,7 +811,7 @@ fn check_param( union_member_type, args, true, - substitutor, + mapper, ); } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs index 1ad1c18ef..70456257f 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs @@ -1,20 +1,22 @@ use std::{ops::Deref, sync::Arc}; use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr}; -use hashbrown::HashSet; +use hashbrown::{HashMap, HashSet}; use rowan::TextRange; use crate::{ DbIndex, DocTypeInferContext, GenericTpl, GenericTplId, LuaFunctionType, LuaSemanticDeclId, - LuaType, LuaTypeNode, SemanticDeclLevel, SemanticModel, TypeOps, TypeSubstitutor, VariadicType, + LuaType, LuaTypeNode, SemanticDeclLevel, SemanticModel, TypeMapper, TypeOps, VariadicType, infer_doc_type, }; +use super::TypeMapperValue; + // 泛型约束上下文 pub struct CallConstraintContext { pub params: Vec<(String, Option)>, pub args: Vec, - pub substitutor: TypeSubstitutor, + pub mapper: TypeMapper, } pub struct CallConstraintArg { @@ -30,11 +32,8 @@ pub fn build_call_constraint_context( let doc_func = infer_call_doc_function(semantic_model, call_expr)?; let mut params = doc_func.get_params().to_vec(); let mut args = get_arg_infos(semantic_model, call_expr)?; - let mut substitutor = TypeSubstitutor::new(); let generic_tpls = collect_func_tpl_ids(¶ms); - if !generic_tpls.is_empty() { - substitutor.add_need_infer_tpls(generic_tpls); - } + let mut mapper_builder = CallConstraintMapperBuilder::new(generic_tpls); // 读取显式传入的泛型实参 if let Some(type_list) = call_expr.get_call_generic_type_list() { @@ -42,7 +41,7 @@ pub fn build_call_constraint_context( DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); for (idx, doc_type) in type_list.get_types().enumerate() { let ty = infer_doc_type(doc_ctx, &doc_type); - substitutor.insert_type(GenericTplId::Func(idx as u32), ty, true); + mapper_builder.bind_type(GenericTplId::Func(idx as u32), ty); } } @@ -65,12 +64,12 @@ pub fn build_call_constraint_context( } } - collect_generic_assignments(&mut substitutor, ¶ms, &args); + collect_generic_assignments(&mut mapper_builder, ¶ms, &args); Some(CallConstraintContext { params, args, - substitutor, + mapper: mapper_builder.into_mapper(), }) } @@ -84,7 +83,7 @@ pub fn normalize_constraint_type(db: &DbIndex, ty: LuaType) -> LuaType { // 收集各个参数对应的泛型推导 fn collect_generic_assignments( - substitutor: &mut TypeSubstitutor, + mapper_builder: &mut CallConstraintMapperBuilder, params: &[(String, Option)], args: &[CallConstraintArg], ) { @@ -95,7 +94,7 @@ fn collect_generic_assignments( let Some(arg) = args.get(idx) else { continue; }; - record_generic_assignment(param_type, &arg.check_type, substitutor); + record_generic_assignment(param_type, &arg.check_type, mapper_builder); } } @@ -256,31 +255,85 @@ fn collect_generic_tpl_from_fallback( fn record_generic_assignment( param_type: &LuaType, arg_type: &LuaType, - substitutor: &mut TypeSubstitutor, + mapper_builder: &mut CallConstraintMapperBuilder, ) { match param_type { LuaType::TplRef(tpl_ref) => { if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), true); + mapper_builder.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); } } LuaType::ConstTplRef(tpl_ref) => { if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), false); + mapper_builder.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); } } LuaType::StrTplRef(str_tpl_ref) => { - substitutor.insert_type(str_tpl_ref.get_tpl_id(), arg_type.clone(), true); + mapper_builder.bind_type(str_tpl_ref.get_tpl_id(), arg_type.clone()); } LuaType::Variadic(variadic) => { if let Some(inner) = variadic.get_type(0) { - record_generic_assignment(inner, arg_type, substitutor); + record_generic_assignment(inner, arg_type, mapper_builder); } } _ => {} } } +struct CallConstraintMapperBuilder { + bindings: Vec<(GenericTplId, Option)>, + binding_indices: HashMap, +} + +impl CallConstraintMapperBuilder { + fn new(generic_tpls: HashSet) -> Self { + let mut bindings = Vec::with_capacity(generic_tpls.len()); + let mut binding_indices = HashMap::with_capacity(generic_tpls.len()); + for tpl_id in generic_tpls { + binding_indices.insert(tpl_id, bindings.len()); + bindings.push((tpl_id, None)); + } + + Self { + bindings, + binding_indices, + } + } + + fn bind_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType) { + if tpl_id.is_conditional_infer() { + return; + } + + if let Some(index) = self.binding_indices.get(&tpl_id).copied() { + if let Some((_, existing_type)) = self.bindings.get_mut(index) + && existing_type.is_none() + { + *existing_type = Some(replace_type); + } + return; + } + + self.binding_indices.insert(tpl_id, self.bindings.len()); + self.bindings.push((tpl_id, Some(replace_type))); + } + + fn into_mapper(self) -> TypeMapper { + let (sources, targets): (Vec<_>, Vec<_>) = self + .bindings + .into_iter() + .map(|(tpl_id, ty)| { + ( + tpl_id, + ty.map(TypeMapperValue::type_value) + .unwrap_or(TypeMapperValue::None), + ) + }) + .unzip(); + TypeMapper::from_values(sources, targets) + } +} + // 解析冒号调用时调用者的具体类型 fn infer_call_source_type( semantic_model: &SemanticModel, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/context.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/context.rs new file mode 100644 index 000000000..c111c832a --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/context.rs @@ -0,0 +1,460 @@ +use std::{cell::RefCell, rc::Rc, sync::Arc}; + +use emmylua_parser::LuaCallExpr; +use hashbrown::{HashMap, HashSet}; + +use crate::{ + DbIndex, GenericTpl, GenericTplId, LuaInferCache, LuaType, check_type_compact, + instantiate_type_generic, semantic::generic::regularize_tpl_candidate_type, +}; + +use crate::semantic::generic::{TypeMapper, TypeMapperValue}; + +use super::{ + InferenceCandidate, InferenceCandidateView, InferencePriority, InferenceResult, + InferenceVariance, is_tpl_at_top_level, resolve::resolve_info, +}; + +#[derive(Debug, Clone)] +pub(super) struct InferenceRecord { + pub(super) candidate: InferenceCandidate, + pub(super) priority: InferencePriority, +} + +#[derive(Debug, Clone, Default)] +pub(super) struct InferenceInfo { + // 协变候选, 例如参数位置或返回值位置的正向推断. + pub(super) covariant: Vec, + // 逆变候选, 例如函数参数约束的反向传播. + pub(super) contravariant: Vec, + // 多返回值或变参展开时收集的候选. + pub(super) multi: Vec, + // 已经固定下来的结果, 例如显式泛型或变参参数表. + pub(super) fixed: Option, + // 最近一次写入的推断结果, 便于在后续阶段判断是否需要重算. + pub(super) inferred: Option, + // 是否已经被 fixing mapper 读取并固定. + pub(super) is_fixed: bool, + // 当前信息是否始终处于顶层位置. + pub(super) top_level: bool, + // 该模板收到的最高优先级. + pub(super) priority: Option, +} + +impl InferenceInfo { + fn new() -> Self { + Self { + top_level: true, + ..Self::default() + } + } + + fn add_candidate( + &mut self, + variance: InferenceVariance, + candidate: InferenceCandidate, + top_level: bool, + priority: InferencePriority, + ) { + if self.is_fixed { + return; + } + + self.top_level &= top_level; + self.priority = Some( + self.priority + .map_or(priority, |current| current.max(priority)), + ); + self.inferred = None; + let record = InferenceRecord { + candidate, + priority, + }; + match variance { + InferenceVariance::Covariant => self.covariant.push(record), + InferenceVariance::Contravariant => self.contravariant.push(record), + } + } +} + +#[derive(Debug)] +pub(in crate::semantic::generic) struct InferenceContext<'a> { + pub db: &'a DbIndex, + pub cache: &'a mut LuaInferCache, + pub call_expr: Option, + infos: HashMap, +} + +impl<'a> InferenceContext<'a> { + pub fn new( + db: &'a DbIndex, + cache: &'a mut LuaInferCache, + call_expr: Option, + ) -> Self { + Self { + db, + cache, + call_expr, + infos: HashMap::new(), + } + } + + pub fn prepare_inference_slots(&mut self, tpl_ids: HashSet) { + for tpl_id in tpl_ids { + if tpl_id.is_conditional_infer() { + continue; + } + + self.infos.entry(tpl_id).or_insert_with(InferenceInfo::new); + } + } + + pub fn has_unresolved_inference_slots(&self) -> bool { + self.infos.values().any(|info| { + info.fixed.is_none() + && info.covariant.is_empty() + && info.contravariant.is_empty() + && info.multi.is_empty() + }) + } + + pub fn fix_type(&mut self, tpl_id: GenericTplId, ty: LuaType) { + if tpl_id.is_conditional_infer() { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + info.fixed = Some(InferenceResult::Type(ty)); + info.inferred = info.fixed.clone(); + info.is_fixed = true; + } + + pub fn add_variadic_params( + &mut self, + tpl_id: GenericTplId, + params: Vec<(String, Option)>, + ) { + if !self.can_bind(tpl_id) { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + if info.fixed.is_some() + || !info.multi.is_empty() + || !info.covariant.is_empty() + || !info.contravariant.is_empty() + { + return; + } + + info.fixed = Some(InferenceResult::VariadicParams(params)); + info.inferred = info.fixed.clone(); + info.is_fixed = true; + } + + pub fn add_variadic_base(&mut self, tpl_id: GenericTplId, base: LuaType) { + if !self.can_bind(tpl_id) { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + if info.fixed.is_some() + || !info.multi.is_empty() + || !info.covariant.is_empty() + || !info.contravariant.is_empty() + { + return; + } + + info.fixed = Some(InferenceResult::VariadicBase(base)); + info.inferred = info.fixed.clone(); + info.is_fixed = true; + } + + pub fn insert_type( + &mut self, + tpl_id: GenericTplId, + candidate: InferenceCandidate, + variance: InferenceVariance, + top_level: bool, + priority: InferencePriority, + ) { + if !self.can_bind(tpl_id) { + return; + } + + self.infos + .entry(tpl_id) + .or_default() + .add_candidate(variance, candidate, top_level, priority); + } + + pub fn insert_multi_types( + &mut self, + tpl_id: GenericTplId, + types: Vec, + view: InferenceCandidateView, + top_level: bool, + priority: InferencePriority, + ) { + if !self.can_bind(tpl_id) { + return; + } + + let info = self.infos.entry(tpl_id).or_default(); + let fixed = if types.len() == 1 { + Some(InferenceResult::VariadicParams(vec![( + "var0".to_string(), + Some(types[0].clone()), + )])) + } else { + Some(InferenceResult::MultiTypes(types.clone())) + }; + if info.fixed.is_some() + || !info.multi.is_empty() + || !info.covariant.is_empty() + || !info.contravariant.is_empty() + { + if top_level { + info.fixed = fixed; + info.inferred = info.fixed.clone(); + info.is_fixed = true; + } + return; + } + + info.multi = types + .into_iter() + .map(|ty| InferenceRecord { + candidate: InferenceCandidate { ty, view }, + priority, + }) + .collect(); + info.top_level &= top_level; + info.priority = Some( + info.priority + .map_or(priority, |current| current.max(priority)), + ); + info.inferred = None; + } + + pub(super) fn inferred_variadic_len(&self, tpl_id: GenericTplId) -> Option { + let info = self.infos.get(&tpl_id)?; + if let Some(fixed) = &info.fixed { + return match fixed { + InferenceResult::Type(_) => Some(1), + InferenceResult::MultiTypes(types) => Some(types.len().max(1)), + InferenceResult::VariadicParams(params) => Some(params.len().max(1)), + InferenceResult::VariadicBase(_) => None, + }; + } + + if !info.multi.is_empty() { + return Some(info.multi.len().max(1)); + } + + if !info.covariant.is_empty() || !info.contravariant.is_empty() { + return Some(1); + } + + None + } + + pub fn fixing_mapper<'b>( + &mut self, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + ) -> TypeMapper { + self.mapper_inner(generic_tpls, return_type, true) + } + + pub fn non_fixing_mapper<'b>( + &mut self, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + ) -> TypeMapper { + self.mapper_inner(generic_tpls, return_type, false) + } + + fn mapper_inner<'b>( + &mut self, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + fixing: bool, + ) -> TypeMapper { + let generic_tpls = generic_tpls.into_iter().collect::>(); + let sources = generic_tpls + .iter() + .map(|tpl| tpl.get_tpl_id()) + .collect::>(); + let mut fallback_indices = HashMap::with_capacity(generic_tpls.len()); + let fallback_targets = generic_tpls + .iter() + .enumerate() + .map(|(index, tpl)| { + fallback_indices.entry(tpl.get_tpl_id()).or_insert(index); + self.infos + .get(&tpl.get_tpl_id()) + .and_then(|info| info.inferred.clone().or_else(|| info.fixed.clone())) + .map(inference_result_to_mapper_value) + .unwrap_or(TypeMapperValue::None) + }) + .collect::>(); + let fallback_targets = Rc::new(RefCell::new(fallback_targets)); + let fallback_mapper = TypeMapper::from_inference_fallback( + Rc::new(fallback_indices), + fallback_targets.clone(), + ); + let targets = (0..generic_tpls.len()) + .map(|index| { + let value = self.get_inferred_mapper_value( + &generic_tpls, + index, + return_type, + &fallback_mapper, + fixing, + ); + if value != TypeMapperValue::None + && let Some(target) = fallback_targets.borrow_mut().get_mut(index) + { + *target = value.clone(); + } + value + }) + .collect(); + TypeMapper::from_values(sources, targets) + } + + fn can_bind(&self, tpl_id: GenericTplId) -> bool { + !tpl_id.is_conditional_infer() + } + + fn get_inferred_mapper_value( + &mut self, + generic_tpls: &[&Arc], + index: usize, + return_type: &LuaType, + fallback_mapper: &TypeMapper, + fixing: bool, + ) -> TypeMapperValue { + let tpl = generic_tpls[index]; + let tpl_id = tpl.get_tpl_id(); + let return_top_level = is_tpl_at_top_level(self.db, return_type, tpl_id); + + if let Some(info) = self.infos.get_mut(&tpl_id) + && let Some(inferred) = info.inferred.clone() + { + if fixing { + info.is_fixed = true; + } + return inference_result_to_mapper_value(inferred); + } + + let inferred = self.infos.get_mut(&tpl_id).and_then(|info| { + if fixing { + info.is_fixed = true; + } + + if let Some(inferred) = &info.inferred { + return Some(inferred.clone()); + } + + let result = resolve_info(self.db, tpl, info, return_top_level, fallback_mapper)?; + info.inferred = Some(result.clone()); + Some(result) + }); + let value = if let Some(inferred) = inferred { + let value = inference_result_to_mapper_value(inferred); + apply_inferred_constraint(self.db, tpl, value, fallback_mapper) + } else { + TypeMapperValue::None + }; + + if value != TypeMapperValue::None + && let Some(info) = self.infos.get_mut(&tpl_id) + { + info.inferred = Some(mapper_value_to_inference_result(value.clone())); + } + + value + } +} + +pub(super) fn inference_result_to_mapper_value(result: InferenceResult) -> TypeMapperValue { + match result { + InferenceResult::Type(ty) => TypeMapperValue::type_value(ty), + InferenceResult::MultiTypes(types) => TypeMapperValue::MultiTypes( + types + .into_iter() + .map(|ty| { + TypeMapperValue::type_value(ty) + .raw_type() + .unwrap_or(LuaType::Unknown) + }) + .collect(), + ), + InferenceResult::VariadicParams(params) => TypeMapperValue::params_value(params), + InferenceResult::VariadicBase(base) => TypeMapperValue::MultiBase( + TypeMapperValue::type_value(base) + .raw_type() + .unwrap_or(LuaType::Unknown), + ), + } +} + +fn mapper_value_to_inference_result(value: TypeMapperValue) -> InferenceResult { + match value { + TypeMapperValue::None => InferenceResult::Type(LuaType::Unknown), + TypeMapperValue::Type(ty) => InferenceResult::Type(ty), + TypeMapperValue::Params(params) => InferenceResult::VariadicParams(params), + TypeMapperValue::MultiTypes(types) => InferenceResult::MultiTypes(types), + TypeMapperValue::MultiBase(base) => InferenceResult::VariadicBase(base), + } +} + +fn apply_inferred_constraint( + db: &DbIndex, + tpl: &GenericTpl, + value: TypeMapperValue, + mapper: &TypeMapper, +) -> TypeMapperValue { + let TypeMapperValue::Type(inferred_type) = &value else { + return value; + }; + + let Some(constraint) = tpl.get_constraint() else { + return value; + }; + + let instantiated_constraint = instantiate_type_generic(db, constraint, mapper); + if inferred_satisfies_constraint(db, inferred_type, &instantiated_constraint) { + return value; + } + + value +} + +fn inferred_satisfies_constraint(db: &DbIndex, inferred: &LuaType, constraint: &LuaType) -> bool { + if inferred == constraint || check_type_compact(db, inferred, constraint).is_ok() { + return true; + } + + match constraint { + LuaType::Union(union) => { + return union + .into_vec() + .iter() + .any(|member| inferred_satisfies_constraint(db, inferred, member)); + } + LuaType::MultiLineUnion(union) => { + return union + .get_unions() + .iter() + .any(|(member, _)| inferred_satisfies_constraint(db, inferred, member)); + } + _ => {} + } + + let regular = regularize_tpl_candidate_type(db, inferred.clone()); + regular != *inferred && inferred_satisfies_constraint(db, ®ular, constraint) +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs new file mode 100644 index 000000000..4cadc27b2 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/infer_types.rs @@ -0,0 +1,1606 @@ +use std::{collections::HashMap as StdHashMap, ops::Deref, sync::Arc}; + +use emmylua_parser::{LuaAstNode, LuaExpr}; +use itertools::Itertools; +use rowan::NodeOrToken; +use smol_str::SmolStr; + +use crate::{ + GenericTplId, InferFailReason, InferGuard, InferGuardRef, LuaFunctionType, LuaGenericType, + LuaMemberInfo, LuaMemberKey, LuaMemberOwner, LuaSemanticDeclId, LuaTupleType, LuaType, + LuaTypeNode, LuaUnionType, SemanticDeclLevel, VariadicType, check_type_compact, + infer_node_semantic_decl, instantiate_type_generic, + semantic::{ + generic::TypeMapper, + member::{find_index_operations, get_member_map}, + }, +}; + +use super::{ + InferenceCandidate, InferenceCandidateView, InferenceContext, InferencePriority, + InferenceVariance, escape_alias, get_str_tpl_infer_type, +}; + +pub(in crate::semantic::generic) fn infer_types( + context: &mut InferenceContext, + source: &LuaType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, +) -> Result<(), InferFailReason> { + infer_types_inner( + context, + source, + target, + original_target, + variance, + priority, + None, + &InferGuard::new(), + ) +} + +pub(in crate::semantic::generic) fn infer_types_from_expr( + context: &mut InferenceContext, + source: &LuaType, + target: &LuaType, + original_target: &LuaType, + arg_expr: &LuaExpr, +) -> Result<(), InferFailReason> { + infer_types_inner( + context, + source, + target, + original_target, + InferenceVariance::Covariant, + InferencePriority::Normal, + Some(arg_expr), + &InferGuard::new(), + ) +} + +fn infer_types_inner( + context: &mut InferenceContext, + source: &LuaType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let target = escape_alias(context.db, target); + if !source.contains_tpl_node() { + return Ok(()); + } + + let top_level = target == *original_target; + match source { + LuaType::TplRef(tpl) => { + if tpl.get_tpl_id().is_func() { + let candidate = match variance { + InferenceVariance::Covariant => { + InferenceCandidate::from_expr_arg(arg_expr, target.clone()) + } + InferenceVariance::Contravariant => InferenceCandidate::ordinary(target), + }; + context.insert_type(tpl.get_tpl_id(), candidate, variance, top_level, priority); + } + } + LuaType::ConstTplRef(tpl) => { + if tpl.get_tpl_id().is_func() { + context.insert_type( + tpl.get_tpl_id(), + InferenceCandidate::const_preserving(target), + variance, + top_level, + priority, + ); + } + } + LuaType::StrTplRef(str_tpl) => { + if let LuaType::StringConst(s) | LuaType::DocStringConst(s) = target { + let type_name = SmolStr::new(format!( + "{}{}{}", + str_tpl.get_prefix(), + s, + str_tpl.get_suffix() + )); + context.insert_type( + str_tpl.get_tpl_id(), + InferenceCandidate::regular_type(get_str_tpl_infer_type(&type_name)), + variance, + top_level, + priority, + ); + } + } + LuaType::Array(array_type) => { + array_infer_types( + context, + array_type.get_base(), + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::TableGeneric(params) => { + table_generic_infer_types( + context, + params, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Generic(generic) => { + generic_infer_types( + context, + generic, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Union(union) => { + let members = union.into_vec(); + let mut error_count = 0; + let mut last_error = InferFailReason::None; + for member in &members { + match infer_types_inner( + context, + member, + &target, + original_target, + variance, + priority, + arg_expr, + &infer_guard.fork(), + ) { + Ok(_) => {} + Err(err) => { + error_count += 1; + last_error = err; + } + } + } + if error_count == members.len() { + return Err(last_error); + } + } + LuaType::DocFunction(func) => { + function_infer_types( + context, + func, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Tuple(tuple) => { + tuple_infer_types( + context, + tuple, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Object(object) => { + object_infer_types( + context, + object, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +fn array_infer_types( + context: &mut InferenceContext, + base: &LuaType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::Array(target_array) => infer_types_inner( + context, + base, + target_array.get_base(), + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?, + LuaType::Tuple(target_tuple) => { + let target_base = target_tuple.cast_down_array_base(context.db); + infer_types_inner( + context, + base, + &target_base, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Object(target_object) => { + let target_base = target_object + .cast_down_array_base(context.db) + .ok_or(InferFailReason::None)?; + infer_types_inner( + context, + base, + &target_base, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +fn table_generic_infer_types( + context: &mut InferenceContext, + table_params: &[LuaType], + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + if table_params.len() != 2 { + return Err(InferFailReason::None); + } + + match target { + LuaType::TableGeneric(target_params) => { + let min_len = table_params.len().min(target_params.len()); + for i in 0..min_len { + infer_types_inner( + context, + &table_params[i], + &target_params[i], + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + LuaType::Array(target_array) => { + infer_types_inner( + context, + &table_params[0], + &LuaType::Integer, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + target_array.get_base(), + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Tuple(target_tuple) => { + let keys = (0..target_tuple.get_types().len()) + .map(|i| LuaType::IntegerConst((i as i64) + 1)) + .collect::>(); + let key_type = LuaType::Union(LuaUnionType::from_vec(keys).into()); + let target_base = target_tuple.cast_down_array_base(context.db); + infer_types_inner( + context, + &table_params[0], + &key_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + &target_base, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::TableConst(inst) => { + table_generic_member_owner_infer_types( + context, + table_params, + LuaMemberOwner::Element(inst.clone()), + &[], + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + table_generic_member_owner_infer_types( + context, + table_params, + LuaMemberOwner::Type(type_id.clone()), + &[], + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Generic(generic) => { + table_generic_member_owner_infer_types( + context, + table_params, + LuaMemberOwner::Type(generic.get_base_type_id()), + generic.get_params(), + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Object(obj) => { + let mut keys = + Vec::with_capacity(obj.get_fields().len() + obj.get_index_access().len()); + let mut values = + Vec::with_capacity(obj.get_fields().len() + obj.get_index_access().len()); + for (key, value) in obj.get_fields() { + match key { + LuaMemberKey::Integer(i) => keys.push(LuaType::IntegerConst(*i)), + LuaMemberKey::Name(name) => { + keys.push(LuaType::StringConst(name.clone().into())) + } + _ => {} + } + values.push(value.clone()); + } + for (key, value) in obj.get_index_access() { + keys.push(key.clone()); + values.push(value.clone()); + } + let key_type = LuaType::Union(LuaUnionType::from_vec(keys).into()); + let value_type = LuaType::Union(LuaUnionType::from_vec(values).into()); + infer_types_inner( + context, + &table_params[0], + &key_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + &value_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Global | LuaType::Any | LuaType::Table | LuaType::Userdata => { + infer_types_inner( + context, + &table_params[0], + &LuaType::Any, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + &LuaType::Any, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn table_generic_member_owner_infer_types( + context: &mut InferenceContext, + table_params: &[LuaType], + owner: LuaMemberOwner, + target_params: &[LuaType], + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + if table_params.len() != 2 { + return Err(InferFailReason::None); + } + + let owner_type = match &owner { + LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()), + LuaMemberOwner::Type(type_id) => match target_params.len() { + 0 => LuaType::Ref(type_id.clone()), + _ => LuaType::Generic(Arc::new(LuaGenericType::new( + type_id.clone(), + target_params.to_vec(), + ))), + }, + _ => return Err(InferFailReason::None), + }; + + let members = get_member_map(context.db, &owner_type).ok_or(InferFailReason::None)?; + if is_pairs_call(context).unwrap_or(false) + && try_handle_pairs_metamethod( + context, + table_params, + &members, + original_target, + variance, + priority, + arg_expr, + infer_guard, + ) + .is_ok() + { + return Ok(()); + } + + let target_key_type = table_params[0].clone(); + let mut keys = Vec::with_capacity(members.len()); + let mut values = Vec::with_capacity(members.len()); + for (key, members) in members { + let key_type = match key { + LuaMemberKey::Integer(i) => LuaType::IntegerConst(i), + LuaMemberKey::Name(name) => LuaType::StringConst(name.clone().into()), + LuaMemberKey::ExprType(ty) => ty, + _ => continue, + }; + + if !target_key_type.is_generic() + && check_type_compact(context.db, &target_key_type, &key_type).is_err() + { + continue; + } + + keys.push(key_type); + values.push(member_infos_type(members)); + } + + if keys.is_empty() { + find_index_operations(context.db, &owner_type) + .ok_or(InferFailReason::None)? + .iter() + .for_each(|member| { + if target_key_type.is_generic() { + return; + } + let LuaMemberKey::ExprType(key_type) = &member.key else { + return; + }; + if check_type_compact(context.db, &target_key_type, key_type).is_ok() { + keys.push(key_type.clone()); + values.push(member.typ.clone()); + } + }); + } + + let key_type = match &keys[..] { + [] => return Err(InferFailReason::None), + [first] => first.clone(), + _ => LuaType::Union(LuaUnionType::from_vec(keys).into()), + }; + let value_type = match &values[..] { + [first] => first.clone(), + _ => LuaType::Union(LuaUnionType::from_vec(values).into()), + }; + + infer_types_inner( + context, + &table_params[0], + &key_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + &value_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + ) +} + +fn generic_infer_types( + context: &mut InferenceContext, + source_generic: &LuaGenericType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::Generic(target_generic) => { + let source_base = source_generic.get_base_type_id_ref(); + let target_base = target_generic.get_base_type_id_ref(); + if source_base == target_base { + for (start, (source_param, target_param)) in source_generic + .get_params() + .iter() + .zip(target_generic.get_params()) + .enumerate() + { + match source_param { + LuaType::Variadic(variadic) => { + variadic_infer_types( + context, + variadic, + &target_generic.get_params()[start..], + original_target, + variance, + priority, + )?; + break; + } + _ => infer_types_inner( + context, + source_param, + target_param, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?, + } + } + return Ok(()); + } + + let target_decl = context + .db + .get_type_index() + .get_type_decl(target_base) + .ok_or(InferFailReason::None)?; + if target_decl.is_alias() { + let mapper = TypeMapper::from_alias( + context.db, + target_generic.get_params().clone(), + target_base, + ); + if let Some(origin_type) = target_decl.get_alias_origin(context.db, Some(&mapper)) { + return generic_infer_types( + context, + source_generic, + &origin_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + ); + } + } else if let Some(super_types) = + context.db.get_type_index().get_super_types(target_base) + { + for mut super_type in super_types { + if super_type.contains_tpl_node() { + let mapper = + TypeMapper::from_type_array(target_generic.get_params().clone()); + super_type = instantiate_type_generic(context.db, &super_type, &mapper); + } + generic_infer_types( + context, + source_generic, + &super_type, + original_target, + variance, + priority, + arg_expr, + &infer_guard.fork(), + )?; + } + } + } + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + infer_guard.check(type_id)?; + let type_decl = context + .db + .get_type_index() + .get_type_decl(type_id) + .ok_or(InferFailReason::None)?; + if let Some(origin_type) = type_decl.get_alias_origin(context.db, None) { + return generic_infer_types( + context, + source_generic, + &origin_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + ); + } + + for super_type in context + .db + .get_type_index() + .get_super_types(type_id) + .unwrap_or_default() + { + generic_infer_types( + context, + source_generic, + &super_type, + original_target, + variance, + priority, + arg_expr, + &infer_guard.fork(), + )?; + } + } + LuaType::Union(union) => { + for member in union.into_vec() { + generic_infer_types( + context, + source_generic, + &member, + original_target, + variance, + priority, + arg_expr, + &infer_guard.fork(), + )?; + } + } + _ => { + let mapper = TypeMapper::empty(); + let generic_ty = LuaType::Generic(source_generic.clone().into()); + let ty = instantiate_type_generic(context.db, &generic_ty, &mapper); + if LuaType::from(source_generic.clone()) != ty { + infer_types_inner( + context, + &ty, + target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + } + + Ok(()) +} + +fn function_infer_types( + context: &mut InferenceContext, + source_func: &LuaFunctionType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::DocFunction(target_func) => { + function_doc_infer_types( + context, + source_func, + target_func, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + LuaType::Signature(signature_id) => { + let signature = context + .db + .get_signature_index() + .get(signature_id) + .ok_or(InferFailReason::None)?; + if !signature.is_resolve_return() { + return check_lambda_inference(context, *signature_id); + } + + let fake_func = signature.to_doc_func_type(); + function_doc_infer_types( + context, + source_func, + &fake_func, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +fn function_doc_infer_types( + context: &mut InferenceContext, + source_func: &LuaFunctionType, + target_func: &LuaFunctionType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let mut source_params = source_func.get_params().to_vec(); + if source_func.is_colon_define() { + source_params.insert(0, ("self".to_string(), Some(LuaType::Any))); + } + + let mut target_params = target_func.get_params().to_vec(); + if target_func.is_colon_define() { + target_params.insert(0, ("self".to_string(), Some(LuaType::Any))); + } + + param_list_infer_types( + context, + &source_params, + &target_params, + original_target, + variance.flip(), + priority, + arg_expr, + infer_guard, + )?; + return_type_infer_types( + context, + source_func.get_ret(), + target_func.get_ret(), + original_target, + variance, + priority, + arg_expr, + infer_guard, + ) +} + +#[allow(clippy::too_many_arguments)] +fn param_list_infer_types( + context: &mut InferenceContext, + sources: &[(String, Option)], + targets: &[(String, Option)], + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let mut target_offset = 0; + for i in 0..sources.len() { + let source = match sources.get(i) { + Some((_, ty)) => ty.clone().unwrap_or(LuaType::Any), + None => break, + }; + + match &source { + LuaType::Variadic(inner) => { + let i = i + target_offset; + if i >= targets.len() { + if let VariadicType::Base(base) = inner.deref() { + insert_tpl_ref_candidate( + context, + base, + LuaType::Nil, + variance, + false, + priority, + ); + } + break; + } + + if let Some((tpl_id, _)) = variadic_base_tpl_ref(inner.deref()) + && let Some(len) = context.inferred_variadic_len(tpl_id) + { + target_offset += len - 1; + continue; + } + + let mut target_rest_params = &targets[i..]; + if i + 1 < sources.len() { + let source_rest_len = sources.len() - i - 1; + if source_rest_len >= target_rest_params.len() { + continue; + } + let target_rest_len = target_rest_params.len() - source_rest_len; + target_rest_params = &target_rest_params[..target_rest_len]; + if target_rest_len > 1 { + target_offset += target_rest_len - 1; + } + } + + function_varargs_infer_types(context, inner, target_rest_params)?; + } + _ => { + let target = match targets.get(i + target_offset) { + Some((_, ty)) => ty.clone().unwrap_or(LuaType::Any), + None => break, + }; + infer_types_inner( + context, + &source, + &target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub(in crate::semantic::generic) fn return_type_infer_types( + context: &mut InferenceContext, + source: &LuaType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match (source, target) { + (LuaType::Variadic(source_variadic), LuaType::Variadic(target_variadic)) => { + match target_variadic.deref() { + VariadicType::Base(target_base) => match source_variadic.deref() { + VariadicType::Base(source_base) => { + insert_tpl_ref_candidate( + context, + source_base, + target_base.clone(), + variance, + false, + priority, + ); + } + VariadicType::Multi(source_multi) => { + for ret_type in source_multi { + match ret_type { + LuaType::Variadic(inner) => { + if let VariadicType::Base(base) = inner.deref() { + insert_tpl_ref_candidate( + context, + base, + target_base.clone(), + variance, + false, + priority, + ); + } + break; + } + _ => { + insert_tpl_ref_candidate( + context, + ret_type, + target_base.clone(), + variance, + false, + priority, + ); + } + } + } + } + }, + VariadicType::Multi(target_types) => { + variadic_infer_types( + context, + source_variadic, + target_types, + original_target, + variance, + priority, + )?; + } + } + } + (LuaType::Variadic(variadic), _) => { + variadic_infer_types( + context, + variadic, + std::slice::from_ref(target), + original_target, + variance, + priority, + )?; + } + (_, LuaType::Variadic(variadic)) => { + multi_param_infer_multi_return( + context, + std::slice::from_ref(source), + variadic, + original_target, + variance, + priority, + )?; + } + _ => infer_types_inner( + context, + source, + target, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?, + } + + Ok(()) +} + +fn tpl_ref_info(ty: &LuaType) -> Option<(GenericTplId, bool)> { + match ty { + LuaType::TplRef(tpl_ref) => Some((tpl_ref.get_tpl_id(), false)), + LuaType::ConstTplRef(tpl_ref) => Some((tpl_ref.get_tpl_id(), true)), + _ => None, + } +} + +fn variadic_base_tpl_ref(variadic: &VariadicType) -> Option<(GenericTplId, bool)> { + let VariadicType::Base(base) = variadic else { + return None; + }; + tpl_ref_info(base) +} + +fn tpl_ref_candidate(is_const_tpl: bool, ty: LuaType) -> InferenceCandidate { + if is_const_tpl { + InferenceCandidate::const_preserving(ty) + } else { + InferenceCandidate::ordinary(ty) + } +} + +fn insert_tpl_ref_candidate( + context: &mut InferenceContext, + source: &LuaType, + target: LuaType, + variance: InferenceVariance, + top_level: bool, + priority: InferencePriority, +) { + if let Some((tpl_id, is_const_tpl)) = tpl_ref_info(source) { + context.insert_type( + tpl_id, + tpl_ref_candidate(is_const_tpl, target), + variance, + top_level, + priority, + ); + } +} + +fn function_varargs_infer_types( + context: &mut InferenceContext, + variadic: &VariadicType, + target_rest_params: &[(String, Option)], +) -> Result<(), InferFailReason> { + if let Some((tpl_id, _)) = variadic_base_tpl_ref(variadic) { + context.add_variadic_params( + tpl_id, + target_rest_params + .iter() + .map(|(name, ty)| (name.clone(), ty.clone())) + .collect(), + ); + } + + Ok(()) +} + +pub(in crate::semantic::generic) fn variadic_infer_types( + context: &mut InferenceContext, + source: &VariadicType, + target_rest_types: &[LuaType], + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, +) -> Result<(), InferFailReason> { + match source { + VariadicType::Base(base) => match base { + LuaType::TplRef(tpl_ref) => { + let tpl_id = tpl_ref.get_tpl_id(); + match target_rest_types.len() { + 0 => context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::Nil), + variance, + false, + priority, + ), + 1 => match &target_rest_types[0] { + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Multi(types) => match types.len() { + 0 => context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::Nil), + variance, + false, + priority, + ), + 1 => context.insert_type( + tpl_id, + InferenceCandidate::ordinary(types[0].clone()), + variance, + false, + priority, + ), + _ => context.insert_multi_types( + tpl_id, + types.to_vec(), + InferenceCandidateView::Ordinary, + false, + priority, + ), + }, + VariadicType::Base(base) => { + context.add_variadic_base(tpl_id, base.clone()); + } + }, + target => context.insert_type( + tpl_id, + InferenceCandidate::ordinary(target.clone()), + variance, + false, + priority, + ), + }, + _ => context.insert_multi_types( + tpl_id, + target_rest_types.to_vec(), + InferenceCandidateView::Ordinary, + false, + priority, + ), + } + } + LuaType::ConstTplRef(tpl_ref) => { + let tpl_id = tpl_ref.get_tpl_id(); + match target_rest_types.len() { + 0 => context.insert_type( + tpl_id, + InferenceCandidate::const_preserving(LuaType::Nil), + variance, + false, + priority, + ), + 1 => context.insert_type( + tpl_id, + InferenceCandidate::const_preserving(target_rest_types[0].clone()), + variance, + false, + priority, + ), + _ => context.insert_multi_types( + tpl_id, + target_rest_types.to_vec(), + InferenceCandidateView::ConstPreserving, + false, + priority, + ), + } + } + _ => {} + }, + VariadicType::Multi(multi) => { + for (i, ret_type) in multi.iter().enumerate() { + match ret_type { + LuaType::Variadic(inner) => { + if i < target_rest_types.len() { + variadic_infer_types( + context, + inner, + &target_rest_types[i..], + original_target, + variance, + priority, + )?; + } + break; + } + _ => { + let Some(target) = target_rest_types.get(i) else { + break; + }; + insert_tpl_ref_candidate( + context, + ret_type, + target.clone(), + variance, + target == original_target, + priority, + ); + } + } + } + } + } + + Ok(()) +} + +pub(in crate::semantic::generic) fn multi_param_infer_multi_return( + context: &mut InferenceContext, + source_params: &[LuaType], + multi_return: &VariadicType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, +) -> Result<(), InferFailReason> { + match multi_return { + VariadicType::Base(base) => { + let mut target_types = Vec::with_capacity(source_params.len()); + for param in source_params { + if param.is_variadic() { + target_types.push(LuaType::Variadic(multi_return.clone().into())); + break; + } else { + target_types.push(base.clone()); + } + } + infer_type_list( + context, + source_params, + &target_types, + original_target, + variance, + priority, + )?; + } + VariadicType::Multi(_) => { + let mut target_types = Vec::with_capacity(source_params.len()); + for (i, param) in source_params.iter().enumerate() { + let Some(return_type) = multi_return.get_type(i) else { + break; + }; + if param.is_variadic() { + target_types.push(LuaType::Variadic( + multi_return.get_new_variadic_from(i).into(), + )); + break; + } else { + target_types.push(return_type.clone()); + } + } + infer_type_list( + context, + source_params, + &target_types, + original_target, + variance, + priority, + )?; + } + } + + Ok(()) +} + +pub(in crate::semantic::generic) fn infer_type_list( + context: &mut InferenceContext, + source_types: &[LuaType], + target_types: &[LuaType], + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, +) -> Result<(), InferFailReason> { + for (start, (source, target)) in source_types.iter().zip(target_types).enumerate() { + match (source, target) { + (LuaType::Variadic(variadic), _) => { + variadic_infer_types( + context, + variadic, + &target_types[start..], + original_target, + variance, + priority, + )?; + break; + } + (_, LuaType::Variadic(variadic)) => { + multi_param_infer_multi_return( + context, + &source_types[start..], + variadic, + original_target, + variance, + priority, + )?; + break; + } + _ => infer_types(context, source, target, original_target, variance, priority)?, + } + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn tuple_infer_types( + context: &mut InferenceContext, + source_tuple: &LuaTupleType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::Tuple(target_tuple) => { + for (i, source_type) in source_tuple.get_types().iter().enumerate() { + if let LuaType::Variadic(inner) = source_type { + variadic_infer_types( + context, + inner, + &target_tuple.get_types()[i..], + original_target, + variance, + priority, + )?; + break; + } + let Some(target_type) = target_tuple.get_types().get(i) else { + break; + }; + infer_types_inner( + context, + source_type, + target_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + LuaType::Array(target_array) => { + let Some(last_type) = source_tuple.get_types().last() else { + return Err(InferFailReason::None); + }; + if let LuaType::Variadic(inner) = last_type + && let Some((tpl_id, _)) = variadic_base_tpl_ref(inner.deref()) + { + context.add_variadic_base(tpl_id, target_array.get_base().clone()); + } + } + _ => {} + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn object_infer_types( + context: &mut InferenceContext, + source_obj: &crate::LuaObjectType, + target: &LuaType, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + match target { + LuaType::Object(target_obj) => { + for (key, value) in source_obj + .get_fields() + .iter() + .sorted_by_key(|(key, _)| *key) + { + if let Some(target_value) = target_obj.get_fields().get(key) { + infer_types_inner( + context, + value, + target_value, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + for (source_key, value) in source_obj.get_index_access() { + let target_access = target_obj + .get_index_access() + .iter() + .find(|(target_key, _)| { + check_type_compact(context.db, source_key, target_key).is_ok() + }); + if let Some((target_key, target_value)) = target_access { + infer_types_inner( + context, + source_key, + target_key, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + value, + target_value, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + } + LuaType::TableConst(inst) => { + object_member_owner_infer_types( + context, + source_obj, + LuaMemberOwner::Element(inst.clone()), + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + _ => {} + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn object_member_owner_infer_types( + context: &mut InferenceContext, + source_obj: &crate::LuaObjectType, + owner: LuaMemberOwner, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let owner_type = match &owner { + LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()), + LuaMemberOwner::Type(type_id) => LuaType::Ref(type_id.clone()), + _ => return Err(InferFailReason::None), + }; + + let members = get_member_map(context.db, &owner_type).ok_or(InferFailReason::None)?; + for (key, members) in members { + let resolve_type = member_infos_type(members); + if let Some(field_value) = source_obj.get_field(&key) { + infer_types_inner( + context, + field_value, + &resolve_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + } + } + + Ok(()) +} + +fn member_infos_type(members: Vec) -> LuaType { + match members.len() { + 0 => LuaType::Any, + 1 => members[0].typ.clone(), + _ => LuaType::from_vec(members.into_iter().map(|member| member.typ).collect()), + } +} + +fn is_pairs_call(context: &mut InferenceContext) -> Option { + let call_expr = context.call_expr.as_ref()?; + let prefix_expr = call_expr.get_prefix_expr()?; + let semantic_decl = match prefix_expr.syntax().clone().into() { + NodeOrToken::Node(node) => infer_node_semantic_decl( + context.db, + context.cache, + node, + SemanticDeclLevel::default(), + ), + _ => None, + }?; + + let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl else { + return None; + }; + let decl = context.db.get_decl_index().get_decl(&decl_id)?; + if !context.db.get_module_index().is_std(&decl.get_file_id()) { + return None; + } + if decl.get_name() != "pairs" { + return None; + } + + Some(true) +} + +#[allow(clippy::too_many_arguments)] +fn try_handle_pairs_metamethod( + context: &mut InferenceContext, + table_params: &[LuaType], + members: &StdHashMap>, + original_target: &LuaType, + variance: InferenceVariance, + priority: InferencePriority, + arg_expr: Option<&LuaExpr>, + infer_guard: &InferGuardRef, +) -> Result<(), InferFailReason> { + let pairs_member = members + .get(&LuaMemberKey::Name("__pairs".into())) + .ok_or(InferFailReason::None)? + .first() + .ok_or(InferFailReason::None)?; + + let meta_return = match &pairs_member.typ { + LuaType::Signature(signature_id) => context + .db + .get_signature_index() + .get(signature_id) + .map(|signature| signature.get_return_type()), + LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), + _ => None, + } + .ok_or(InferFailReason::None)?; + + let iterator_func = meta_return.get_result_slot_type(0).unwrap_or(meta_return); + let final_return_type = match iterator_func { + LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), + LuaType::Signature(signature_id) => context + .db + .get_signature_index() + .get(&signature_id) + .map(|signature| signature.get_return_type()), + _ => None, + }; + + if let Some(final_return_type) = &final_return_type { + let key_type = final_return_type + .get_result_slot_type(0) + .ok_or(InferFailReason::None)?; + let value_type = final_return_type + .get_result_slot_type(1) + .unwrap_or(LuaType::Nil); + infer_types_inner( + context, + &table_params[0], + &key_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + infer_types_inner( + context, + &table_params[1], + &value_type, + original_target, + variance, + priority, + arg_expr, + infer_guard, + )?; + return Ok(()); + } + + Err(InferFailReason::None) +} + +fn check_lambda_inference( + context: &mut InferenceContext, + signature_id: crate::LuaSignatureId, +) -> Result<(), InferFailReason> { + let call_expr = context.call_expr.as_ref().ok_or(InferFailReason::None)?; + let call_arg_list = call_expr.get_args_list().ok_or(InferFailReason::None)?; + for arg in call_arg_list.get_args() { + if let Ok(LuaType::Signature(arg_signature_id)) = + crate::semantic::infer_expr(context.db, context.cache, arg.clone()) + && arg_signature_id == signature_id + { + return Ok(()); + } + } + + Err(InferFailReason::UnResolveSignatureReturn(signature_id)) +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs new file mode 100644 index 000000000..32e77624e --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/mod.rs @@ -0,0 +1,251 @@ +use std::ops::Deref; + +use emmylua_parser::LuaExpr; +use hashbrown::HashSet; + +use crate::{DbIndex, GenericTplId, LuaType, LuaTypeDeclId, VariadicType}; + +mod context; +mod infer_types; +mod resolve; + +#[cfg(test)] +mod tests; + +pub(in crate::semantic::generic) use context::InferenceContext; +pub(in crate::semantic::generic) use infer_types::{ + infer_type_list, infer_types_from_expr, multi_param_infer_multi_return, + return_type_infer_types, variadic_infer_types, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(in crate::semantic::generic) enum InferenceCandidateView { + FreshExpression, + RegularType, + ConstPreserving, + Ordinary, +} + +#[derive(Debug, Clone, PartialEq)] +pub(in crate::semantic::generic) struct InferenceCandidate { + pub(super) ty: LuaType, + pub(super) view: InferenceCandidateView, +} + +impl InferenceCandidate { + pub(in crate::semantic::generic) fn from_expr_arg(expr: Option<&LuaExpr>, ty: LuaType) -> Self { + if is_literal_candidate(&ty) { + if expr.is_some_and(is_fresh_literal_expr) { + return Self::fresh_expression(ty); + } + + return Self::regular_type(ty); + } + + Self::ordinary(ty) + } + + pub(in crate::semantic::generic) fn regular_type(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::RegularType, + } + } + + pub(in crate::semantic::generic) fn const_preserving(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::ConstPreserving, + } + } + + pub(in crate::semantic::generic) fn ordinary(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::Ordinary, + } + } + + fn fresh_expression(ty: LuaType) -> Self { + Self { + ty, + view: InferenceCandidateView::FreshExpression, + } + } + + pub(super) fn candidate_type(&self) -> LuaType { + self.ty.clone() + } + + pub(super) fn is_const_preserving(&self) -> bool { + self.view == InferenceCandidateView::ConstPreserving + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(in crate::semantic::generic) enum InferencePriority { + Normal, + Return, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(in crate::semantic::generic) enum InferenceVariance { + Covariant, + Contravariant, +} + +impl InferenceVariance { + pub(in crate::semantic::generic) fn flip(self) -> Self { + match self { + InferenceVariance::Covariant => InferenceVariance::Contravariant, + InferenceVariance::Contravariant => InferenceVariance::Covariant, + } + } +} + +#[derive(Debug, Clone)] +pub(in crate::semantic::generic) enum InferenceResult { + Type(LuaType), + MultiTypes(Vec), + VariadicParams(Vec<(String, Option)>), + VariadicBase(LuaType), +} + +pub(in crate::semantic::generic) fn is_literal_candidate(ty: &LuaType) -> bool { + match ty { + LuaType::StringConst(_) + | LuaType::DocStringConst(_) + | LuaType::IntegerConst(_) + | LuaType::DocIntegerConst(_) + | LuaType::FloatConst(_) + | LuaType::BooleanConst(_) + | LuaType::DocBooleanConst(_) + | LuaType::TableConst(_) => true, + LuaType::Union(union) => union.into_vec().iter().any(is_literal_candidate), + LuaType::MultiLineUnion(union) => union + .get_unions() + .iter() + .any(|(ty, _)| is_literal_candidate(ty)), + LuaType::Tuple(tuple) => tuple.get_types().iter().any(is_literal_candidate), + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => is_literal_candidate(base), + VariadicType::Multi(types) => types.iter().any(is_literal_candidate), + }, + _ => false, + } +} + +fn is_fresh_literal_expr(expr: &LuaExpr) -> bool { + match expr { + LuaExpr::LiteralExpr(_) | LuaExpr::TableExpr(_) => true, + LuaExpr::ParenExpr(paren) => paren + .get_expr() + .is_some_and(|expr| is_fresh_literal_expr(&expr)), + _ => false, + } +} + +pub(in crate::semantic::generic) fn get_str_tpl_infer_type(name: &str) -> LuaType { + match name { + "unknown" => LuaType::Unknown, + "never" => LuaType::Never, + "nil" | "void" => LuaType::Nil, + "any" => LuaType::Any, + "userdata" => LuaType::Userdata, + "thread" => LuaType::Thread, + "boolean" | "bool" => LuaType::Boolean, + "string" => LuaType::String, + "integer" | "int" => LuaType::Integer, + "number" => LuaType::Number, + "io" => LuaType::Io, + "self" => LuaType::SelfInfer, + "global" => LuaType::Global, + "function" => LuaType::Function, + _ => LuaType::Ref(LuaTypeDeclId::global(name)), + } +} + +pub(in crate::semantic::generic) fn escape_alias(db: &DbIndex, may_alias: &LuaType) -> LuaType { + if let LuaType::Ref(type_id) = may_alias + && let Some(type_decl) = db.get_type_index().get_type_decl(type_id) + && type_decl.is_alias() + && let Some(origin_type) = type_decl.get_alias_origin(db, None) + { + return origin_type.clone(); + } + + may_alias.clone() +} + +pub(super) fn is_tpl_at_top_level(db: &DbIndex, ty: &LuaType, tpl_id: GenericTplId) -> bool { + is_tpl_at_top_level_with_guard(db, ty, tpl_id, &mut HashSet::new()) +} + +fn is_tpl_at_top_level_with_guard( + db: &DbIndex, + ty: &LuaType, + tpl_id: GenericTplId, + visited_aliases: &mut HashSet, +) -> bool { + match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + LuaType::Union(union) => union.into_vec().iter().any(|member| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), + LuaType::MultiLineUnion(multi) => multi.get_unions().iter().any(|(member, _)| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), + LuaType::Generic(generic) => { + let type_decl_id = generic.get_base_type_id_ref(); + let Some(alias_param) = + get_transparent_alias_param_index(db, type_decl_id, visited_aliases) + else { + return false; + }; + + generic.get_params().get(alias_param).is_some_and(|param| { + is_tpl_at_top_level_with_guard(db, param, tpl_id, visited_aliases) + }) + } + _ => false, + } +} + +fn get_transparent_alias_param_index( + db: &DbIndex, + type_decl_id: &LuaTypeDeclId, + visited_aliases: &mut HashSet, +) -> Option { + if !visited_aliases.insert(type_decl_id.clone()) { + return None; + } + + let type_decl = db.get_type_index().get_type_decl(type_decl_id)?; + if !type_decl.is_alias() { + return None; + }; + let origin = type_decl.get_alias_ref()?; + + match origin { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + LuaType::Generic(generic) => { + get_transparent_alias_param_index(db, generic.get_base_type_id_ref(), visited_aliases) + .and_then(|alias_param| generic.get_params().get(alias_param)) + .and_then(|param| match param { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + _ => None, + }) + } + _ => None, + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/resolve.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/resolve.rs new file mode 100644 index 000000000..07135328e --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/resolve.rs @@ -0,0 +1,135 @@ +use crate::{DbIndex, GenericTpl, LuaType, TypeOps, instantiate_type_generic}; + +use crate::semantic::generic::{ + TypeMapper, is_primitive_or_literal_type, regularize_tpl_candidate_type, + widen_tpl_candidate_type, +}; + +use super::context::{InferenceInfo, InferenceRecord}; +use super::{InferenceCandidate, InferenceCandidateView, InferencePriority, InferenceResult}; + +pub(super) fn resolve_info( + db: &DbIndex, + tpl: &GenericTpl, + info: &InferenceInfo, + return_top_level: bool, + mapper: &TypeMapper, +) -> Option { + if let Some(fixed) = &info.fixed { + return Some(fixed.clone()); + } + + if !info.multi.is_empty() { + let primitive_constraint = tpl.get_constraint().is_some_and(|constraint| { + let constraint = instantiate_type_generic(db, constraint, mapper); + is_primitive_or_literal_type(&constraint) + }); + let const_preserving = info + .multi + .iter() + .any(|record| record.candidate.is_const_preserving()); + let preserve_root_literal_form = + primitive_constraint || const_preserving || return_top_level; + return Some(InferenceResult::MultiTypes( + info.multi + .iter() + .map(|record| { + resolve_candidate_type(db, &record.candidate, preserve_root_literal_form) + }) + .collect(), + )); + } + + if !info.covariant.is_empty() { + return Some(InferenceResult::Type(resolve_covariant_candidates( + db, + tpl, + info, + return_top_level, + mapper, + ))); + } + + if !info.contravariant.is_empty() { + return Some(InferenceResult::Type(combine_records( + db, + &info.contravariant, + info.priority.unwrap_or(InferencePriority::Normal), + true, + TypeOps::Intersect, + ))); + } + + None +} + +fn resolve_covariant_candidates( + db: &DbIndex, + tpl: &GenericTpl, + info: &InferenceInfo, + return_top_level: bool, + mapper: &TypeMapper, +) -> LuaType { + let primitive_constraint = tpl + .get_constraint() + .map(|constraint| { + let constraint = instantiate_type_generic(db, constraint, mapper); + is_primitive_or_literal_type(&constraint) + }) + .unwrap_or(false); + let const_preserving = info + .covariant + .iter() + .any(|record| record.candidate.is_const_preserving()); + let preserve_root_literal_form = + primitive_constraint || const_preserving || !info.top_level || return_top_level; + + combine_records( + db, + &info.covariant, + info.priority.unwrap_or(InferencePriority::Normal), + preserve_root_literal_form, + TypeOps::Union, + ) +} + +fn combine_records( + db: &DbIndex, + records: &[InferenceRecord], + max_priority: InferencePriority, + preserve_root_literal_form: bool, + op: TypeOps, +) -> LuaType { + let mut selected = records + .iter() + .filter(|record| record.priority == max_priority) + .map(|record| resolve_candidate_type(db, &record.candidate, preserve_root_literal_form)); + + let Some(first) = selected.next() else { + return LuaType::Unknown; + }; + + selected.fold(first, |acc, ty| op.apply(db, &acc, &ty)) +} + +fn resolve_candidate_type( + db: &DbIndex, + candidate: &InferenceCandidate, + preserve_root_literal_form: bool, +) -> LuaType { + if candidate.is_const_preserving() { + // std.ConstTpl 需要保留结构字面量, 例如 tuple/object/table const. + return candidate.candidate_type(); + } + + if preserve_root_literal_form { + return regularize_tpl_candidate_type(db, candidate.candidate_type()); + } + + match candidate.view { + InferenceCandidateView::FreshExpression | InferenceCandidateView::Ordinary => { + widen_tpl_candidate_type(db, candidate.candidate_type()) + } + _ => regularize_tpl_candidate_type(db, candidate.candidate_type()), + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs b/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs new file mode 100644 index 000000000..3024f4794 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/inference/tests.rs @@ -0,0 +1,196 @@ +use std::sync::Arc; + +use hashbrown::HashSet; +use smol_str::SmolStr; + +use super::infer_types::infer_types; +use super::{ + InferenceCandidate, InferenceContext, InferencePriority, InferenceResult, InferenceVariance, + context::inference_result_to_mapper_value, is_literal_candidate, return_type_infer_types, +}; +use crate::{ + CacheOptions, DbIndex, FileId, GenericTpl, GenericTplId, InferGuard, LuaInferCache, + LuaMultiLineUnion, LuaStringTplType, LuaTupleStatus, LuaTupleType, LuaType, LuaTypeDeclId, + TypeOps, VariadicType, + semantic::generic::{TypeMapperValue, get_mapped_value}, +}; + +#[test] +fn non_fixing_mapper_does_not_lock_later_candidates() { + let db = DbIndex::new(); + let mut cache = LuaInferCache::new(FileId::VIRTUAL, CacheOptions::default()); + let mut context = InferenceContext::new(&db, &mut cache, None); + let tpl_id = GenericTplId::Func(0); + let tpl = Arc::new(GenericTpl::new( + tpl_id, + SmolStr::new("T").into(), + None, + None, + )); + let return_type = LuaType::TplRef(tpl.clone()); + + context.prepare_inference_slots(HashSet::from([tpl_id])); + + // non_fixing_mapper 会读取当前推断结果, 但不应把泛型槽位标记为固定. + context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::String), + InferenceVariance::Covariant, + true, + InferencePriority::Normal, + ); + + let non_fixing = context.non_fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &non_fixing).and_then(|value| value.raw_type()), + Some(LuaType::String) + ); + + // 如果上面的 mapper 错误地固定了 T, 这里的新候选会被忽略; + // 正确行为是继续合并协变候选, 得到 string | integer. + context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::Integer), + InferenceVariance::Covariant, + true, + InferencePriority::Normal, + ); + let expected = TypeOps::Union.apply(&db, &LuaType::String, &LuaType::Integer); + let fixing = context.fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &fixing).and_then(|value| value.raw_type()), + Some(expected.clone()) + ); + + // fixing_mapper 才会真正锁定槽位; 锁定后再插入候选不应改变结果. + context.insert_type( + tpl_id, + InferenceCandidate::ordinary(LuaType::Boolean), + InferenceVariance::Covariant, + true, + InferencePriority::Normal, + ); + let fixed = context.fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &fixed).and_then(|value| value.raw_type()), + Some(expected) + ); +} + +#[test] +fn variadic_params_mapper_normalizes_def_types() { + let type_id = LuaTypeDeclId::global("Alias"); + let value = inference_result_to_mapper_value(InferenceResult::VariadicParams(vec![( + "value".to_string(), + Some(LuaType::Def(type_id.clone())), + )])); + + assert_eq!( + value, + TypeMapperValue::Params(vec![("value".to_string(), Some(LuaType::Ref(type_id)))]) + ); +} + +#[test] +fn literal_candidate_detects_multi_line_union_members() { + let literal_union = LuaType::MultiLineUnion( + LuaMultiLineUnion::new(vec![ + (LuaType::DocStringConst(SmolStr::new("left").into()), None), + (LuaType::Integer, Some("fallback".to_string())), + ]) + .into(), + ); + assert!(is_literal_candidate(&literal_union)); + + let non_literal_union = LuaType::MultiLineUnion( + LuaMultiLineUnion::new(vec![ + (LuaType::String, None), + (LuaType::Integer, Some("fallback".to_string())), + ]) + .into(), + ); + assert!(!is_literal_candidate(&non_literal_union)); +} + +#[test] +fn str_tpl_inference_accepts_doc_string_const() { + let db = DbIndex::new(); + let mut cache = LuaInferCache::new(FileId::VIRTUAL, CacheOptions::default()); + let mut context = InferenceContext::new(&db, &mut cache, None); + let tpl_id = GenericTplId::Func(0); + let tpl = Arc::new(GenericTpl::new( + tpl_id, + SmolStr::new("T").into(), + None, + None, + )); + context.prepare_inference_slots(HashSet::from([tpl_id])); + + let source = + LuaType::StrTplRef(LuaStringTplType::new("aaa.", "T", tpl_id, ".bbb", None).into()); + let target = LuaType::DocStringConst(SmolStr::new("xxx").into()); + + infer_types( + &mut context, + &source, + &target, + &target, + InferenceVariance::Covariant, + InferencePriority::Normal, + ) + .expect("string template inference"); + + let return_type = LuaType::TplRef(tpl.clone()); + let mapper = context.fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &mapper).and_then(|value| value.raw_type()), + Some(LuaType::Ref(LuaTypeDeclId::global("aaa.xxx.bbb"))) + ); +} + +#[test] +fn return_variadic_const_tpl_ref_preserves_structural_base() { + let db = DbIndex::new(); + let mut cache = LuaInferCache::new(FileId::VIRTUAL, CacheOptions::default()); + let mut context = InferenceContext::new(&db, &mut cache, None); + let tpl_id = GenericTplId::Func(0); + let tpl = Arc::new(GenericTpl::new( + tpl_id, + SmolStr::new("T").into(), + None, + None, + )); + context.prepare_inference_slots(HashSet::from([tpl_id])); + + let literal_tuple = LuaType::Tuple( + LuaTupleType::new( + vec![ + LuaType::StringConst(SmolStr::new("mode").into()), + LuaType::IntegerConst(1), + ], + LuaTupleStatus::InferResolve, + ) + .into(), + ); + let source = LuaType::Variadic(VariadicType::Base(LuaType::ConstTplRef(tpl.clone())).into()); + let target = LuaType::Variadic(VariadicType::Base(literal_tuple.clone()).into()); + + return_type_infer_types( + &mut context, + &source, + &target, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + None, + &InferGuard::new(), + ) + .expect("return variadic inference"); + + let return_type = LuaType::ConstTplRef(tpl.clone()); + let mapper = context.fixing_mapper(std::iter::once(&tpl), &return_type); + assert_eq!( + get_mapped_value(tpl_id, &mapper).and_then(|value| value.raw_type()), + Some(literal_tuple) + ); +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs index 8ac1a2644..0b11a8c68 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs @@ -7,7 +7,7 @@ use crate::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaTupleType, LuaType, LuaUnionType, VariadicType, }, - semantic::generic::type_substitutor::TypeSubstitutor, + semantic::generic::{TypeMapper, TypeMapperValue}, }; use super::instantiate_type_generic; @@ -90,13 +90,15 @@ fn complete_type_generic_args_inner( } let mut params = Vec::with_capacity(generic_params.len().max(provided_args.len())); - let mut substitutor = TypeSubstitutor::new(); + let mut prefix_sources = Vec::with_capacity(generic_params.len()); + let mut prefix_targets = Vec::with_capacity(generic_params.len()); let mut missing_required_count = 0; let mut cycled = false; for (idx, generic_param) in generic_params.iter().enumerate() { if let Some(provided_arg) = provided_args.get(idx) { let provided_arg = provided_arg.clone(); - substitutor.insert_type(GenericTplId::Type(idx as u32), provided_arg.clone(), true); + prefix_sources.push(GenericTplId::Type(idx as u32)); + prefix_targets.push(provided_arg.clone()); params.push(provided_arg); continue; } @@ -114,8 +116,17 @@ fn complete_type_generic_args_inner( } else { completed_type.ty }; - let instantiated = instantiate_type_generic(db, &default_type, &substitutor); - substitutor.insert_type(GenericTplId::Type(idx as u32), instantiated.clone(), true); + let mapper = TypeMapper::from_values( + prefix_sources.clone(), + prefix_targets + .iter() + .cloned() + .map(TypeMapperValue::type_value) + .collect(), + ); + let instantiated = instantiate_type_generic(db, &default_type, &mapper); + prefix_sources.push(GenericTplId::Type(idx as u32)); + prefix_targets.push(instantiated.clone()); params.push(instantiated); } else { missing_required_count += 1; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs new file mode 100644 index 000000000..ee3f81c35 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/context.rs @@ -0,0 +1,115 @@ +use std::sync::Arc; + +use crate::{DbIndex, LuaType, LuaTypeDeclId}; + +use super::super::TypeMapper; + +const MAX_INSTANTIATION_DEPTH: usize = 128; +const MAX_ALIAS_STACK: usize = 32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TplResolvePolicy { + /// 未推断模板按 `default -> constraint -> unknown` 推断成实际类型. + Fallback, + /// 没有默认值的未推断模板仍保留为 `TplRef`, 让后续调用点继续参与参数推导. + PreserveTplRef, +} + +#[derive(Debug)] +pub(in crate::semantic::generic) struct GenericInstantiateContext<'a> { + pub db: &'a DbIndex, + pub mapper: TypeMapper, + self_type: Option<&'a LuaType>, + alias_stack: Arc<[LuaTypeDeclId]>, +} + +#[derive(Debug, Clone, Copy)] +pub(in crate::semantic::generic) struct GenericInstantiateFrame { + policy: TplResolvePolicy, + depth: usize, +} + +impl<'a> GenericInstantiateContext<'a> { + pub(super) fn new( + db: &'a DbIndex, + mapper: &TypeMapper, + self_type: Option<&'a LuaType>, + ) -> Self { + Self { + db, + mapper: mapper.clone(), + self_type, + alias_stack: Arc::from([]), + } + } + + pub(super) fn root_frame(&self) -> GenericInstantiateFrame { + GenericInstantiateFrame { + policy: TplResolvePolicy::Fallback, + depth: 0, + } + } + + pub(super) fn self_type(&self) -> Option<&LuaType> { + self.self_type + } + + pub(super) fn with_mapper(&self, mapper: TypeMapper) -> GenericInstantiateContext<'a> { + GenericInstantiateContext { + db: self.db, + mapper, + self_type: self.self_type, + alias_stack: self.alias_stack.clone(), + } + } + + pub(super) fn enter_alias_stack( + &self, + alias_type_id: &LuaTypeDeclId, + ) -> Option> { + if self.alias_stack.len() >= MAX_ALIAS_STACK + || self.alias_stack.iter().any(|id| id == alias_type_id) + { + return None; + } + + let mut alias_stack = Vec::with_capacity(self.alias_stack.len() + 1); + alias_stack.extend(self.alias_stack.iter().cloned()); + alias_stack.push(alias_type_id.clone()); + Some(Arc::from(alias_stack)) + } + + pub(super) fn with_alias_stack( + &self, + alias_stack: Arc<[LuaTypeDeclId]>, + mapper: TypeMapper, + ) -> GenericInstantiateContext<'a> { + GenericInstantiateContext { + db: self.db, + mapper, + self_type: self.self_type, + alias_stack, + } + } +} + +impl GenericInstantiateFrame { + pub(super) fn with_policy(self, policy: TplResolvePolicy) -> Self { + Self { policy, ..self } + } + + pub(super) fn should_preserve_tpl_ref(&self) -> bool { + self.policy == TplResolvePolicy::PreserveTplRef + } + + pub(super) fn enter(self) -> Option { + if self.depth >= MAX_INSTANTIATION_DEPTH { + return None; + } + + Some(Self { + depth: self.depth + 1, + ..self + }) + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs similarity index 70% rename from crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs rename to crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs index 1c8e39996..fa7c8157b 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs @@ -12,34 +12,35 @@ use crate::{ semantic::{ LuaInferCache, generic::{ - instantiate_type::instantiate_doc_function, - tpl_context::TplContext, - tpl_pattern::{ - multi_param_tpl_pattern_match_multi_return, return_type_pattern_match_target_type, - tpl_pattern_match, variadic_tpl_pattern_match, - }, + InferenceContext, InferencePriority, InferenceVariance, infer_type_list, + infer_types_from_expr, multi_param_infer_multi_return, return_type_infer_types, + variadic_infer_types, }, infer::InferFailReason, infer_expr, - overload_resolve::{callable_accepts_args, resolve_signature_by_args}, + overload_resolve::{ + callable_accepts_args, collect_callable_overload_groups, resolve_signature_by_args, + }, }, }; use crate::{ GenericTpl, LuaMemberOwner, LuaSemanticDeclId, LuaTypeOwner, SemanticDeclLevel, TypeVisitTrait, - collect_callable_overload_groups, infer_node_semantic_decl, - tpl_pattern_match_args_skip_unknown, + infer_node_semantic_decl, }; -use super::{TypeSubstitutor, instantiate_type_generic}; +use super::{ + TplResolvePolicy, TypeMapper, TypeMapperValue, instantiate_type_generic, + instantiate_type_generic_full, +}; -pub fn instantiate_func_generic( +pub fn infer_call_func_generic( db: &DbIndex, cache: &mut LuaInferCache, func: &LuaFunctionType, call_expr: LuaCallExpr, ) -> Result { let file_id = cache.get_file_id().clone(); - let (generic_tpls, contain_self) = collect_func_tpl_ids(func); + let (inference_slots, mapper_tpls, contain_self) = collect_func_generic_info(func); let origin_params = func.get_params(); let mut func_params: Vec<_> = origin_params @@ -52,37 +53,42 @@ pub fn instantiate_func_generic( .ok_or(InferFailReason::None)? .get_args() .collect::>(); - let mut substitutor = TypeSubstitutor::new(); - let mut context = TplContext { - db, - cache, - substitutor: &mut substitutor, - call_expr: Some(call_expr.clone()), - }; - if !generic_tpls.is_empty() { - context.substitutor.add_need_infer_tpls(generic_tpls); + let mapper; + { + let mut context = InferenceContext::new(db, cache, Some(call_expr.clone())); + if !inference_slots.is_empty() { + context.prepare_inference_slots(inference_slots); - if let Some(type_list) = call_expr.get_call_generic_type_list() { - // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 - apply_call_generic_type_list(db, file_id, &mut context, &type_list); - } else { - // 如果没有指定泛型, 则需要从调用参数中推断 - infer_generic_types_from_call( - db, - &mut context, - func, - &call_expr, - &mut func_params, - &arg_exprs, - )?; + if let Some(type_list) = call_expr.get_call_generic_type_list() { + // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 + apply_call_generic_type_list(db, file_id, &mut context, &type_list); + } else { + // 如果没有指定泛型, 则需要从调用参数中推断 + infer_generic_types_from_call( + db, + &mut context, + func, + &call_expr, + &mut func_params, + &arg_exprs, + )?; + } } - } - if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { - substitutor.add_self_type(self_type); + mapper = context.fixing_mapper(mapper_tpls.iter(), func.get_ret()); } - if let LuaType::DocFunction(f) = instantiate_doc_function(db, func, &substitutor) { + let self_type = contain_self + .then(|| infer_self_type(db, cache, &call_expr)) + .flatten(); + let func_ty = LuaType::DocFunction(func.clone().into()); + if let LuaType::DocFunction(f) = instantiate_type_generic_full( + db, + &func_ty, + &mapper, + self_type.as_ref(), + TplResolvePolicy::Fallback, + ) { Ok(f.deref().clone()) } else { Ok(func.clone()) @@ -92,19 +98,17 @@ pub fn instantiate_func_generic( fn apply_call_generic_type_list( db: &DbIndex, file_id: FileId, - context: &mut TplContext, + context: &mut InferenceContext, type_list: &LuaDocTypeList, ) { let doc_ctx = DocTypeInferContext::new(db, file_id); for (i, doc_type) in type_list.get_types().enumerate() { let typ = infer_doc_type(doc_ctx, &doc_type); - context - .substitutor - .insert_type(GenericTplId::Func(i as u32), typ, true); + context.fix_type(GenericTplId::Func(i as u32), typ); } } -pub fn as_doc_function_type( +fn as_doc_function_type( db: &DbIndex, callable_type: &LuaType, ) -> Result>, InferFailReason> { @@ -121,7 +125,7 @@ pub fn as_doc_function_type( } fn infer_callable_return_from_arg_types( - context: &mut TplContext, + context: &mut InferenceContext, callable_type: &LuaType, call_arg_types: &[LuaType], ) -> Result, InferFailReason> { @@ -197,8 +201,8 @@ fn uses_erased_function_param(callable: &LuaFunctionType, call_arg_types: &[LuaT }) } -pub fn infer_callable_return_from_remaining_args( - context: &mut TplContext, +fn infer_callable_return_from_remaining_args( + context: &mut InferenceContext, callable_type: &LuaType, arg_exprs: &[LuaExpr], ) -> Result, InferFailReason> { @@ -228,7 +232,7 @@ pub fn infer_callable_return_from_remaining_args( } fn instantiate_callable_from_arg_types( - context: &mut TplContext, + context: &mut InferenceContext, callable: &Arc, call_arg_types: &[LuaType], ) -> Option> { @@ -236,12 +240,11 @@ fn instantiate_callable_from_arg_types( return None; } - let mut callable_tpls = HashSet::new(); - callable.visit_nested_types(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { - callable_tpls.insert(generic_tpl.get_tpl_id()); - } - }); + let (_, callable_mapper_tpls, _) = collect_func_generic_info(callable); + let callable_tpls = callable_mapper_tpls + .iter() + .map(|tpl| tpl.get_tpl_id()) + .collect::>(); if callable_tpls.is_empty() { return Some(callable.clone()); } @@ -251,28 +254,44 @@ fn instantiate_callable_from_arg_types( .iter() .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) .collect::>(); - let mut callable_substitutor = TypeSubstitutor::new(); - callable_substitutor.add_need_infer_tpls(callable_tpls.clone()); - let mut callable_context = TplContext { - db: context.db, - cache: context.cache, - substitutor: &mut callable_substitutor, - call_expr: context.call_expr.clone(), + let (non_fixing_mapper, mapper) = { + let mut callable_context = + InferenceContext::new(context.db, context.cache, context.call_expr.clone()); + callable_context.prepare_inference_slots(callable_tpls.clone()); + if infer_type_list( + &mut callable_context, + &callable_param_types, + call_arg_types, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + ) + .is_err() + { + return None; + } + let non_fixing_mapper = + callable_context.non_fixing_mapper(callable_mapper_tpls.iter(), callable.get_ret()); + let mapper = + callable_context.fixing_mapper(callable_mapper_tpls.iter(), callable.get_ret()); + (non_fixing_mapper, mapper) }; - if tpl_pattern_match_args_skip_unknown( - &mut callable_context, - &callable_param_types, - call_arg_types, - ) - .is_err() - { - return None; - } - let instantiated = match instantiate_doc_function(context.db, callable, &callable_substitutor) { - LuaType::DocFunction(func) => func, - _ => callable.clone(), - }; + let return_only_callable = LuaType::DocFunction( + LuaFunctionType::new( + callable.get_async_state(), + callable.is_colon_define(), + false, + Vec::new(), + callable.get_ret().clone(), + ) + .into(), + ); + let instantiated = + match instantiate_type_generic(context.db, &return_only_callable, &non_fixing_mapper) { + LuaType::DocFunction(func) => func, + _ => callable.clone(), + }; let unresolved_return_tpls = { let mut tpl_ids = HashSet::new(); instantiated.get_ret().visit_type(&mut |ty| { @@ -283,7 +302,11 @@ fn instantiate_callable_from_arg_types( } }); if tpl_ids.is_empty() { - return Some(instantiated); + let callable_ty = LuaType::DocFunction(callable.clone()); + return match instantiate_type_generic(context.db, &callable_ty, &mapper) { + LuaType::DocFunction(func) => Some(func), + _ => Some(instantiated), + }; } tpl_ids }; @@ -298,10 +321,12 @@ fn instantiate_callable_from_arg_types( return None; } + let mut mapper = mapper; for tpl_id in callback_return_tpls { - callable_substitutor.insert_type(tpl_id, LuaType::Unknown, true); + mapper = TypeMapper::prepend(tpl_id, LuaType::Unknown, Some(mapper)); } - match instantiate_doc_function(context.db, callable, &callable_substitutor) { + let callable_ty = LuaType::DocFunction(callable.clone()); + match instantiate_type_generic(context.db, &callable_ty, &mapper) { LuaType::DocFunction(func) => Some(func), _ => None, } @@ -359,22 +384,31 @@ fn collect_callback_return_tpls( callback_return_tpls } -fn collect_func_tpl_ids(func: &LuaFunctionType) -> (HashSet, bool) { - let mut generic_tpls = HashSet::new(); +fn collect_func_generic_info( + func: &LuaFunctionType, +) -> (HashSet, Vec>, bool) { + let mut inference_slots = HashSet::new(); + let mut mapper_tpls = Vec::new(); let mut contain_self = false; - func.visit_nested_types(&mut |ty| match ty { LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - collect_func_tpl_with_fallback_deps(generic_tpl, &mut generic_tpls); + collect_func_tpl_with_fallback_deps(generic_tpl, &mut inference_slots); + if generic_tpl.get_tpl_id().is_func() + && !mapper_tpls + .iter() + .any(|it: &Arc| it.get_tpl_id() == generic_tpl.get_tpl_id()) + { + mapper_tpls.push(generic_tpl.clone()); + } } LuaType::StrTplRef(str_tpl) => { - generic_tpls.insert(str_tpl.get_tpl_id()); + inference_slots.insert(str_tpl.get_tpl_id()); } LuaType::SelfInfer => contain_self = true, _ => {} }); - (generic_tpls, contain_self) + (inference_slots, mapper_tpls, contain_self) } fn collect_func_tpl_with_fallback_deps( @@ -462,7 +496,7 @@ fn collect_func_tpl_dep_from_fallback_type( fn infer_generic_types_from_call( db: &DbIndex, - context: &mut TplContext, + context: &mut InferenceContext, func: &LuaFunctionType, call_expr: &LuaCallExpr, func_params: &mut Vec<(String, LuaType)>, @@ -488,7 +522,7 @@ fn infer_generic_types_from_call( break; } - if context.substitutor.is_infer_all_tpl() { + if !context.has_unresolved_inference_slots() { break; } @@ -511,20 +545,33 @@ fn infer_generic_types_from_call( Err(InferFailReason::FieldNotFound) => LuaType::Nil, // 对于未找到的字段, 我们认为是 nil 以执行后续推断 Err(e) => return Err(e), }; - if let Some(return_pattern) = as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_ret().clone()) { if let Some(inferred_return_type) = infer_callable_return_from_remaining_args(context, &arg_type, &arg_exprs[i + 1..])? { - return_type_pattern_match_target_type( + return_type_infer_types( context, &return_pattern, &inferred_return_type, + &inferred_return_type, + InferenceVariance::Covariant, + InferencePriority::Return, + None, + &crate::InferGuard::new(), )?; } else if arg_type.is_any() || arg_type.is_unknown() { - return_type_pattern_match_target_type(context, &return_pattern, &LuaType::Unknown)?; + return_type_infer_types( + context, + &return_pattern, + &LuaType::Unknown, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + None, + &crate::InferGuard::new(), + )?; } } @@ -535,7 +582,14 @@ fn infer_generic_types_from_call( let arg_type = infer_expr(db, context.cache, arg_expr.clone())?; arg_types.push(arg_type); } - variadic_tpl_pattern_match(context, variadic, &arg_types)?; + variadic_infer_types( + context, + variadic, + &arg_types, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + )?; break; } (_, LuaType::Variadic(variadic)) => { @@ -544,20 +598,39 @@ fn infer_generic_types_from_call( .map(|(_, t)| t) .cloned() .collect::>(); - multi_param_tpl_pattern_match_multi_return(context, &func_param_types, variadic)?; + multi_param_infer_multi_return( + context, + &func_param_types, + variadic, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + )?; break; } _ => { - tpl_pattern_match(context, func_param_type, &arg_type)?; + infer_types_from_expr( + context, + func_param_type, + &arg_type, + &arg_type, + call_arg_expr, + )?; } } } - if !context.substitutor.is_infer_all_tpl() { + if context.has_unresolved_inference_slots() { for (func_param_type, call_arg_expr) in unresolve_tpls { - let closure_type = infer_expr(db, context.cache, call_arg_expr)?; - - tpl_pattern_match(context, &func_param_type, &closure_type)?; + let closure_type = infer_expr(db, context.cache, call_arg_expr.clone())?; + + infer_types_from_expr( + context, + &func_param_type, + &closure_type, + &closure_type, + &call_arg_expr, + )?; } } @@ -569,11 +642,21 @@ pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { LuaType::Def(id) | LuaType::Ref(id) => { if let Some(generic) = db.get_type_index().get_generic_params(id) { let mut params = Vec::with_capacity(generic.len()); - let mut substitutor = TypeSubstitutor::new(); + let mut prefix_sources = Vec::with_capacity(generic.len()); + let mut prefix_targets = Vec::with_capacity(generic.len()); for (i, generic_param) in generic.iter().enumerate() { let tpl_id = GenericTplId::Type(i as u32); - let param = build_self_generic_arg(db, generic_param, &substitutor); - substitutor.insert_type(tpl_id, param.clone(), true); + let mapper = TypeMapper::from_values( + prefix_sources.clone(), + prefix_targets + .iter() + .cloned() + .map(TypeMapperValue::type_value) + .collect(), + ); + let param = build_self_generic_arg(db, generic_param, &mapper); + prefix_sources.push(tpl_id); + prefix_targets.push(param.clone()); params.push(param); } let generic = LuaGenericType::new(id.clone(), params); @@ -588,7 +671,7 @@ pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { fn build_self_generic_arg( db: &DbIndex, generic_param: &GenericParam, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, ) -> LuaType { let Some(arg) = generic_param .default_type @@ -598,7 +681,7 @@ fn build_self_generic_arg( return LuaType::Unknown; }; - instantiate_type_generic(db, arg, substitutor) + instantiate_type_generic(db, arg, mapper) } pub fn infer_self_type( @@ -651,7 +734,7 @@ pub fn infer_self_type( } fn check_expr_can_later_infer( - context: &mut TplContext, + context: &mut InferenceContext, func_param_type: &LuaType, call_arg_expr: &LuaExpr, ) -> Result { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs new file mode 100644 index 000000000..4ec342ace --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs @@ -0,0 +1,476 @@ +use std::{ops::Deref, sync::Arc}; + +use hashbrown::{HashMap, HashSet}; +use rowan::TextRange; + +use crate::{ + DbIndex, GenericParam, InFiled, LuaArrayType, LuaConditionalType, LuaFunctionType, + LuaGenericType, LuaMappedType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaTupleType, + LuaType, LuaUnionType, TypeOps, VariadicType, +}; + +pub(in crate::semantic::generic) fn is_primitive_or_literal_type(ty: &LuaType) -> bool { + match ty { + LuaType::String + | LuaType::Number + | LuaType::Integer + | LuaType::Boolean + | LuaType::StringConst(_) + | LuaType::DocStringConst(_) + | LuaType::IntegerConst(_) + | LuaType::DocIntegerConst(_) + | LuaType::FloatConst(_) + | LuaType::BooleanConst(_) + | LuaType::DocBooleanConst(_) => true, + LuaType::Tuple(tuple) => tuple.get_types().iter().any(is_primitive_or_literal_type), + LuaType::Union(union) => union.into_vec().iter().any(is_primitive_or_literal_type), + LuaType::MultiLineUnion(union) => union + .get_unions() + .iter() + .any(|(ty, _)| is_primitive_or_literal_type(ty)), + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => is_primitive_or_literal_type(base), + VariadicType::Multi(types) => types.iter().any(is_primitive_or_literal_type), + }, + LuaType::Call(call) => call.get_operands().iter().any(is_primitive_or_literal_type), + _ => false, + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum WideningContext { + Root, + RootUnionMember, + UnionMember, + ObjectProperty, + ArrayElement, + TupleElement, + VariadicElement, +} + +const MAX_WIDENING_DEPTH: u16 = 100; + +#[derive(Default)] +struct WideningGuard { + depth: u16, + active_table_ids: HashSet>, +} + +impl WideningGuard { + fn enter_level(&mut self) -> bool { + if self.depth >= MAX_WIDENING_DEPTH { + return false; + } + self.depth += 1; + true + } + + fn leave_level(&mut self) { + self.depth = self.depth.saturating_sub(1); + } + + fn enter_table(&mut self, table_id: &InFiled) -> bool { + self.active_table_ids.insert(table_id.clone()) + } + + fn leave_table(&mut self, table_id: &InFiled) { + self.active_table_ids.remove(table_id); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum RootPrimitiveBehavior { + PreserveLiteral, + WidenLiteral, +} + +struct WideningTransformer<'db> { + db: Option<&'db DbIndex>, + root_primitive_behavior: RootPrimitiveBehavior, + guard: WideningGuard, +} + +impl<'db> WideningTransformer<'db> { + fn new(db: Option<&'db DbIndex>, root_primitive_behavior: RootPrimitiveBehavior) -> Self { + Self { + db, + root_primitive_behavior, + guard: WideningGuard::default(), + } + } + + fn for_candidate_regularization(db: &'db DbIndex) -> Self { + Self::new(Some(db), RootPrimitiveBehavior::PreserveLiteral) + } + + fn for_candidate_widening(db: &'db DbIndex) -> Self { + Self::new(Some(db), RootPrimitiveBehavior::WidenLiteral) + } + + fn transform(&mut self, ty: LuaType, context: WideningContext) -> LuaType { + if !self.guard.enter_level() { + return self.fallback(ty, context); + } + + let widened = match ty { + LuaType::TableConst(table_id) => self.transform_table_const(table_id), + LuaType::Array(array) => self.transform_array(array, context), + LuaType::Tuple(tuple) => self.transform_tuple(tuple), + LuaType::Object(object) => self.transform_object(object), + LuaType::Union(union) => self.transform_union(union, context), + LuaType::MultiLineUnion(multi) => self.transform_multi_line_union(multi, context), + LuaType::Intersection(intersection) => { + self.transform_intersection(intersection, context) + } + LuaType::Variadic(variadic) => self.transform_variadic(variadic), + LuaType::Generic(generic) => self.transform_generic(generic), + LuaType::TableGeneric(params) => self.transform_table_generic(params), + LuaType::DocFunction(func) => self.transform_doc_function(func), + LuaType::TypeGuard(type_guard) => self.transform_type_guard(type_guard), + LuaType::Conditional(conditional) => self.transform_conditional(conditional), + LuaType::Mapped(mapped) => self.transform_mapped(mapped), + ty => self.transform_terminal(ty, context), + }; + + self.guard.leave_level(); + widened + } + + fn fallback(&self, ty: LuaType, context: WideningContext) -> LuaType { + match (self.db, ty) { + (Some(_), LuaType::TableConst(_)) => LuaType::Table, + (_, ty) => self.transform_terminal(ty, context), + } + } + + fn transform_table_const(&mut self, table_id: InFiled) -> LuaType { + let Some(db) = self.db else { + return LuaType::TableConst(table_id); + }; + + self.table_const_to_object(db, table_id) + .unwrap_or(LuaType::Table) + } + + fn transform_array(&mut self, array: Arc, context: WideningContext) -> LuaType { + let element_context = match context { + WideningContext::TupleElement => WideningContext::TupleElement, + _ => WideningContext::ArrayElement, + }; + let base = self.transform(array.get_base().clone(), element_context); + LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) + } + + fn transform_tuple(&mut self, tuple: Arc) -> LuaType { + let types = tuple + .get_types() + .iter() + .cloned() + .map(|ty| self.transform(ty, WideningContext::TupleElement)) + .collect(); + LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) + } + + fn transform_object(&mut self, object: Arc) -> LuaType { + let fields = object + .get_fields() + .iter() + .map(|(key, ty)| { + ( + key.clone(), + self.transform(ty.clone(), WideningContext::ObjectProperty), + ) + }) + .collect(); + let index_access = object + .get_index_access() + .iter() + .map(|(key, value)| { + ( + self.transform(key.clone(), WideningContext::ObjectProperty), + self.transform(value.clone(), WideningContext::ObjectProperty), + ) + }) + .collect(); + LuaType::Object(LuaObjectType::new_with_fields(fields, index_access).into()) + } + + fn transform_union(&mut self, union: Arc, context: WideningContext) -> LuaType { + let member_context = self.union_member_context(context); + LuaType::Union( + LuaUnionType::from_vec( + union + .into_vec() + .into_iter() + .map(|ty| self.transform(ty, member_context)) + .collect(), + ) + .into(), + ) + } + + fn transform_multi_line_union( + &mut self, + multi: Arc, + context: WideningContext, + ) -> LuaType { + let member_context = self.union_member_context(context); + LuaType::MultiLineUnion( + crate::LuaMultiLineUnion::new( + multi + .get_unions() + .iter() + .map(|(ty, description)| { + ( + self.transform(ty.clone(), member_context), + description.clone(), + ) + }) + .collect(), + ) + .into(), + ) + } + + fn transform_intersection( + &mut self, + intersection: Arc, + context: WideningContext, + ) -> LuaType { + let member_context = self.union_member_context(context); + LuaType::Intersection( + crate::LuaIntersectionType::new( + intersection + .get_types() + .iter() + .cloned() + .map(|ty| self.transform(ty, member_context)) + .collect(), + ) + .into(), + ) + } + + fn transform_variadic(&mut self, variadic: Arc) -> LuaType { + LuaType::Variadic( + match variadic.deref() { + VariadicType::Base(base) => VariadicType::Base( + self.transform(base.clone(), WideningContext::VariadicElement), + ), + VariadicType::Multi(types) => VariadicType::Multi( + types + .iter() + .cloned() + .map(|ty| self.transform(ty, WideningContext::VariadicElement)) + .collect(), + ), + } + .into(), + ) + } + + fn transform_generic(&mut self, generic: Arc) -> LuaType { + LuaType::Generic( + LuaGenericType::new( + generic.get_base_type_id(), + generic + .get_params() + .iter() + .cloned() + .map(|ty| self.transform(ty, WideningContext::Root)) + .collect(), + ) + .into(), + ) + } + + fn transform_table_generic(&mut self, params: Arc>) -> LuaType { + LuaType::TableGeneric( + params + .iter() + .cloned() + .map(|ty| self.transform(ty, WideningContext::Root)) + .collect::>() + .into(), + ) + } + + fn transform_doc_function(&mut self, func: Arc) -> LuaType { + LuaType::DocFunction( + LuaFunctionType::new( + func.get_async_state(), + func.is_colon_define(), + func.is_variadic(), + func.get_params() + .iter() + .map(|(name, ty)| { + ( + name.clone(), + ty.clone() + .map(|ty| self.transform(ty, WideningContext::Root)), + ) + }) + .collect(), + self.transform(func.get_ret().clone(), WideningContext::Root), + ) + .into(), + ) + } + + fn transform_type_guard(&mut self, type_guard: Arc) -> LuaType { + LuaType::TypeGuard( + self.transform(type_guard.deref().clone(), WideningContext::Root) + .into(), + ) + } + + fn transform_conditional(&mut self, conditional: Arc) -> LuaType { + LuaType::Conditional( + LuaConditionalType::new( + self.transform( + conditional.get_checked_type().clone(), + WideningContext::Root, + ), + self.transform( + conditional.get_extends_type().clone(), + WideningContext::Root, + ), + self.transform(conditional.get_true_type().clone(), WideningContext::Root), + self.transform(conditional.get_false_type().clone(), WideningContext::Root), + conditional.get_infer_params().to_vec(), + conditional.has_new, + ) + .into(), + ) + } + + fn transform_mapped(&mut self, mapped: Arc) -> LuaType { + LuaType::Mapped(Arc::new(LuaMappedType::new( + ( + mapped.param.0, + GenericParam::new( + mapped.param.1.name.clone(), + mapped + .param + .1 + .type_constraint + .clone() + .map(|ty| self.transform(ty, WideningContext::Root)), + mapped + .param + .1 + .default_type + .clone() + .map(|ty| self.transform(ty, WideningContext::Root)), + mapped.param.1.attributes.clone(), + ), + ), + self.transform(mapped.value.clone(), WideningContext::Root), + mapped.is_readonly, + mapped.is_optional, + ))) + } + + fn transform_terminal(&self, ty: LuaType, context: WideningContext) -> LuaType { + // Keep a top-level literal union intact. Widening `"a" | "b"` to `string` + // would throw away a deliberate literal candidate during inference. + if matches!(context, WideningContext::RootUnionMember) { + return ty; + } + + if matches!(context, WideningContext::Root) + && matches!( + self.root_primitive_behavior, + RootPrimitiveBehavior::PreserveLiteral + ) + { + return ty; + } + + widen_primitive_literal(ty) + } + + fn union_member_context(&self, context: WideningContext) -> WideningContext { + if matches!(context, WideningContext::Root) { + WideningContext::RootUnionMember + } else { + WideningContext::UnionMember + } + } + + fn table_const_to_object( + &mut self, + db: &DbIndex, + table_id: InFiled, + ) -> Option { + if !self.guard.enter_table(&table_id) { + return Some(LuaType::Table); + } + + let owner = LuaMemberOwner::Element(table_id.clone()); + let members = match db.get_member_index().get_members(&owner) { + Some(members) => members, + None => { + self.guard.leave_table(&table_id); + return None; + } + }; + let mut fields = HashMap::with_capacity(members.len()); + let mut index_access = Vec::with_capacity(members.len()); + + for member in members { + let value = db + .get_type_index() + .get_type_cache(&member.get_id().into()) + .map(|cache| cache.as_type().clone()) + .unwrap_or(LuaType::Unknown); + let value = self.transform(value, WideningContext::ObjectProperty); + + match member.get_key() { + LuaMemberKey::Name(_) | LuaMemberKey::Integer(_) => { + let member_key = member.get_key().clone(); + fields + .entry(member_key) + .and_modify(|prev| { + *prev = TypeOps::Union.apply(db, prev, &value); + }) + .or_insert(value); + } + LuaMemberKey::ExprType(key) => { + index_access.push(( + self.transform(key.clone(), WideningContext::ObjectProperty), + value, + )); + } + LuaMemberKey::None => {} + } + } + + self.guard.leave_table(&table_id); + + Some(LuaType::Object( + LuaObjectType::new_with_fields(fields, index_access).into(), + )) + } +} + +pub(in crate::semantic::generic) fn regularize_tpl_candidate_type( + db: &DbIndex, + ty: LuaType, +) -> LuaType { + WideningTransformer::for_candidate_regularization(db).transform(ty, WideningContext::Root) +} + +pub(in crate::semantic::generic) fn widen_tpl_candidate_type(db: &DbIndex, ty: LuaType) -> LuaType { + WideningTransformer::for_candidate_widening(db).transform(ty, WideningContext::Root) +} + +fn widen_primitive_literal(ty: LuaType) -> LuaType { + match ty { + LuaType::FloatConst(_) => LuaType::Number, + LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, + LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, + LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, + ty => ty, + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index 226a0fb4f..3da8f3231 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -1,5 +1,4 @@ use hashbrown::{HashMap, HashSet}; -use std::ops::Deref; use crate::{ DbIndex, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, @@ -8,16 +7,13 @@ use crate::{ semantic::{member::find_members_with_key, type_check::check_type_compact_with_level}, }; -use super::{ - get_default_constructor, instantiate_type_generic, instantiate_type_generic_with_context, -}; -use crate::semantic::generic::type_substitutor::GenericInstantiateContext; +use super::{GenericInstantiateContext, GenericInstantiateFrame}; +use super::{get_default_constructor, instantiate_type_generic_inner}; +use crate::semantic::generic::{TypeMapper, get_mapped_value}; #[derive(Debug, Clone, Copy)] enum InferVariance { - // 协变 Covariant, - // 逆变 Contravariant, } @@ -38,27 +34,35 @@ struct InferCandidateSet { pub(super) fn instantiate_conditional( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, ) -> LuaType { - if let Some(distributed) = instantiate_distributed_conditional(context, conditional) { + let Some(frame) = frame.enter() else { + return instantiate_conditional_residual(context, frame, conditional, None, None); + }; + + if let Some(distributed) = instantiate_distributed_conditional(context, frame, conditional) { return distributed; } - instantiate_conditional_once(context, conditional) + instantiate_conditional_once(context, frame, conditional) } fn instantiate_conditional_once( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, ) -> LuaType { let left_type = instantiate_conditional_operand( context, + frame, conditional.get_checked_type(), true, conditional.has_new, ); let right_type = instantiate_conditional_operand( context, + frame, conditional.get_extends_type(), false, conditional.has_new, @@ -78,115 +82,166 @@ fn instantiate_conditional_once( ) { instantiate_true_branch( context, + frame, conditional, finalize_infer_assignments(infer_assignments), ) - } else { - instantiate_type_generic( - context.db, - conditional.get_false_type(), - context.substitutor, + } else if is_deferred_conditional_operand(&left_type) + || right_type.any_type(|inner| match inner { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + !tpl.get_tpl_id().is_conditional_infer() + } + LuaType::StrTplRef(_) + | LuaType::SelfInfer + | LuaType::Conditional(_) + | LuaType::Mapped(_) + | LuaType::Call(_) => true, + _ => false, + }) + { + instantiate_conditional_residual( + context, + frame, + conditional, + Some(left_type), + Some(right_type), ) + } else { + instantiate_type_generic_inner(context, frame, conditional.get_false_type()) }; } match check_conditional_extends(context.db, &left_type, &right_type) { - ConditionalCheck::True => instantiate_true_branch(context, conditional, HashMap::new()), - ConditionalCheck::False => instantiate_type_generic( - context.db, - conditional.get_false_type(), - context.substitutor, - ), + ConditionalCheck::True => { + instantiate_true_branch(context, frame, conditional, HashMap::new()) + } + ConditionalCheck::False => { + instantiate_type_generic_inner(context, frame, conditional.get_false_type()) + } ConditionalCheck::Both => { - let true_type = instantiate_true_branch(context, conditional, HashMap::new()); - let false_type = instantiate_type_generic( - context.db, - conditional.get_false_type(), - context.substitutor, - ); + if is_deferred_conditional_operand(&left_type) + || is_deferred_conditional_operand(&right_type) + { + return instantiate_conditional_residual( + context, + frame, + conditional, + Some(left_type), + Some(right_type), + ); + } + let true_type = instantiate_true_branch(context, frame, conditional, HashMap::new()); + let false_type = + instantiate_type_generic_inner(context, frame, conditional.get_false_type()); TypeOps::Union.apply(context.db, &true_type, &false_type) } } } +fn instantiate_conditional_residual( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + conditional: &LuaConditionalType, + checked_type: Option, + extends_type: Option, +) -> LuaType { + let instantiate_branch = |branch: &LuaType| { + if branch.any_type(|ty| match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + get_mapped_value(tpl.get_tpl_id(), &context.mapper).is_some() + } + LuaType::SelfInfer => context.self_type().is_some(), + _ => false, + }) { + instantiate_type_generic_inner(context, frame, branch) + } else { + branch.clone() + } + }; + + LuaType::Conditional( + LuaConditionalType::new( + checked_type.unwrap_or_else(|| { + instantiate_type_generic_inner(context, frame, conditional.get_checked_type()) + }), + extends_type.unwrap_or_else(|| { + instantiate_type_generic_inner(context, frame, conditional.get_extends_type()) + }), + instantiate_branch(conditional.get_true_type()), + instantiate_branch(conditional.get_false_type()), + conditional.get_infer_params().to_vec(), + conditional.has_new, + ) + .into(), + ) +} + /// 处理分布式条件类型, 与`TS`中的分布式条件类型处理方式相同, 只有裸模版参数才会被分布式. fn instantiate_distributed_conditional( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, ) -> Option { - let tpl_id = naked_checked_type_tpl_id(conditional.get_checked_type())?; - let raw_checked_type = context.substitutor.get_raw_type(tpl_id)?; + let tpl_id = match conditional.get_checked_type() { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if tpl.get_tpl_id().is_type() || tpl.get_tpl_id().is_func() => + { + tpl.get_tpl_id() + } + _ => return None, + }; + let raw_checked_type = get_mapped_value(tpl_id, &context.mapper)?.raw_type()?; if raw_checked_type.is_never() { return Some(LuaType::Never); } - let members = union_members(raw_checked_type)?; + let members = match &raw_checked_type { + LuaType::Union(union) => union.into_vec(), + LuaType::MultiLineUnion(multi) => multi + .get_unions() + .iter() + .map(|(member, _)| member.clone()) + .collect(), + _ => return None, + }; let mut result = LuaType::Never; for member in members { - let mut member_substitutor = context.substitutor.clone(); - member_substitutor.replace_type(tpl_id, member, false); - let member_context = context.with_substitutor(&member_substitutor); - let member_result = instantiate_conditional_once(&member_context, conditional); + let member_mapper = TypeMapper::prepend(tpl_id, member, Some(context.mapper.clone())); + let member_context = context.with_mapper(member_mapper); + let member_result = instantiate_conditional_once(&member_context, frame, conditional); result = TypeOps::Union.apply(context.db, &result, &member_result); } Some(result) } -fn naked_checked_type_tpl_id(checked_type: &LuaType) -> Option { - match checked_type { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) if tpl.get_tpl_id().is_type() => { - Some(tpl.get_tpl_id()) - } - _ => None, - } -} - -fn union_members(ty: &LuaType) -> Option> { - match ty { - LuaType::Union(union) => Some(union.into_vec()), - LuaType::MultiLineUnion(multi) => Some( - multi - .get_unions() - .iter() - .map(|(member, _)| member.clone()) - .collect(), - ), - _ => None, - } -} - fn instantiate_true_branch( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, infer_assignments: HashMap, ) -> LuaType { if infer_assignments.is_empty() { - return instantiate_type_generic( - context.db, - conditional.get_true_type(), - context.substitutor, - ); + return instantiate_type_generic_inner(context, frame, conditional.get_true_type()); } - let mut true_substitutor = context.substitutor.clone(); + let mut true_mapper = context.mapper.clone(); for (tpl_id, ty) in infer_assignments { - true_substitutor.insert_conditional_infer_type(tpl_id, ty); + true_mapper = TypeMapper::prepend(tpl_id, ty, Some(true_mapper)); } - instantiate_type_generic(context.db, conditional.get_true_type(), &true_substitutor) + let true_context = context.with_mapper(true_mapper); + instantiate_type_generic_inner(&true_context, frame, conditional.get_true_type()) } fn contains_conditional_infer(ty: &LuaType) -> bool { - ty.any_type(conditional_infer_tpl_id) -} - -fn conditional_infer_tpl_id(ty: &LuaType) -> bool { - matches!( - ty, - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) - if tpl.get_tpl_id().is_conditional_infer() - ) + ty.any_type(|inner| { + matches!( + inner, + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if tpl.get_tpl_id().is_conditional_infer() + ) + }) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -217,6 +272,10 @@ fn check_conditional_extends(db: &DbIndex, source: &LuaType, target: &LuaType) - return ConditionalCheck::True; } + if literal_extends_base_type(source, target) { + return ConditionalCheck::True; + } + if let LuaType::Union(union) = source { let mut result = ConditionalCheck::False; for member in union.into_vec() { @@ -241,6 +300,10 @@ fn check_conditional_extends(db: &DbIndex, source: &LuaType, target: &LuaType) - return ConditionalCheck::False; } + if is_deferred_conditional_operand(source) || is_deferred_conditional_operand(target) { + return ConditionalCheck::Both; + } + if check_type_compact_with_level( db, source, @@ -263,6 +326,25 @@ fn merge_conditional_check(left: ConditionalCheck, right: ConditionalCheck) -> C } } +fn literal_extends_base_type(source: &LuaType, target: &LuaType) -> bool { + matches!( + (source, target), + ( + LuaType::StringConst(_) | LuaType::DocStringConst(_), + LuaType::String + ) | ( + LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_), + LuaType::Integer + ) | ( + LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_) | LuaType::FloatConst(_), + LuaType::Number, + ) | ( + LuaType::BooleanConst(_) | LuaType::DocBooleanConst(_), + LuaType::Boolean + ) + ) +} + fn collect_infer_assignments( db: &DbIndex, source: &LuaType, @@ -653,29 +735,30 @@ fn finalize_infer_assignments( candidates .covariant .or(candidates.contravariant) - .map(|ty| (tpl_id, ty)) + .map(|raw_candidate| (tpl_id, raw_candidate)) }) .collect() } fn instantiate_conditional_operand( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, operand: &LuaType, checked: bool, has_new: bool, ) -> LuaType { - let mut result = instantiate_type_generic_with_context(context, operand); + let mut result = instantiate_type_generic_inner(context, frame, operand); if let LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) = operand { let tpl_id = tpl_ref.get_tpl_id(); - if let Some(raw) = context.substitutor.get_raw_type(tpl_id) { - result = raw.clone(); - } else if checked && result.contains_tpl_node() { - result = LuaType::Unknown; + if let Some(raw) = + get_mapped_value(tpl_id, &context.mapper).and_then(|value| value.raw_type()) + { + result = raw; + } else if checked && result.is_never() { + result = LuaType::Never; } } - result = actualize_unresolved_templates(result); - if has_new && let LuaType::Ref(id) | LuaType::Def(id) = &result && let Some(decl) = context.db.get_type_index().get_type_decl(id) @@ -688,147 +771,17 @@ fn instantiate_conditional_operand( result } -// 条件类型判定只消费已经实例化后的实际类型, 残留的普通模板引用在这里递归收敛为 `unknown`. -// `infer` pattern 也以模板引用表示, 必须保留下来供后续结构匹配绑定. -fn actualize_unresolved_templates(ty: LuaType) -> LuaType { - match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { - if tpl.get_tpl_id().is_conditional_infer() { - // Conditional infer 是右侧 pattern 的占位孔, 不能像普通未解模板一样抹成 unknown. - LuaType::TplRef(tpl) - } else { - LuaType::Unknown - } - } - LuaType::StrTplRef(_) => LuaType::Unknown, - LuaType::Array(array) => LuaType::Array( - crate::LuaArrayType::new( - actualize_unresolved_templates(array.get_base().clone()), - array.get_len().clone(), - ) - .into(), - ), - LuaType::Tuple(tuple) => LuaType::Tuple( - LuaTupleType::new( - tuple - .get_types() - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - tuple.status, - ) - .into(), - ), - LuaType::DocFunction(func) => LuaType::DocFunction( - crate::LuaFunctionType::new( - func.get_async_state(), - func.is_colon_define(), - func.is_variadic(), - func.get_params() - .iter() - .map(|(name, ty)| { - (name.clone(), ty.clone().map(actualize_unresolved_templates)) - }) - .collect(), - actualize_unresolved_templates(func.get_ret().clone()), - ) - .into(), - ), - LuaType::Object(object) => LuaType::Object( - LuaObjectType::new_with_fields( - object - .get_fields() - .iter() - .map(|(key, ty)| (key.clone(), actualize_unresolved_templates(ty.clone()))) - .collect(), - object - .get_index_access() - .iter() - .map(|(key, value)| { - ( - actualize_unresolved_templates(key.clone()), - actualize_unresolved_templates(value.clone()), - ) - }) - .collect(), - ) - .into(), - ), - LuaType::Union(union) => LuaType::from_vec( - union - .into_vec() - .into_iter() - .map(actualize_unresolved_templates) - .collect(), - ), - LuaType::MultiLineUnion(multi) => LuaType::from_vec( - multi - .get_unions() - .iter() - .map(|(ty, _)| actualize_unresolved_templates(ty.clone())) - .collect(), - ), - LuaType::Intersection(intersection) => LuaType::Intersection( - crate::LuaIntersectionType::new( - intersection - .get_types() - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - ) - .into(), - ), - LuaType::Generic(generic) => LuaType::Generic( - crate::LuaGenericType::new( - generic.get_base_type_id(), - generic - .get_params() - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - ) - .into(), - ), - LuaType::TableGeneric(params) => LuaType::TableGeneric( - params - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect::>() - .into(), - ), - LuaType::Variadic(variadic) => LuaType::Variadic( - match variadic.deref() { - crate::VariadicType::Base(base) => { - crate::VariadicType::Base(actualize_unresolved_templates(base.clone())) - } - crate::VariadicType::Multi(types) => crate::VariadicType::Multi( - types - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - ), - } - .into(), - ), - LuaType::TypeGuard(guard) => { - LuaType::TypeGuard(actualize_unresolved_templates(guard.deref().clone()).into()) - } - LuaType::Conditional(conditional) => LuaType::Conditional( - LuaConditionalType::new( - actualize_unresolved_templates(conditional.get_checked_type().clone()), - actualize_unresolved_templates(conditional.get_extends_type().clone()), - actualize_unresolved_templates(conditional.get_true_type().clone()), - actualize_unresolved_templates(conditional.get_false_type().clone()), - conditional.get_infer_params().to_vec(), - conditional.has_new, - ) - .into(), - ), - ty => ty, - } +fn is_deferred_conditional_operand(ty: &LuaType) -> bool { + ty.any_type(|inner| { + matches!( + inner, + LuaType::TplRef(_) + | LuaType::ConstTplRef(_) + | LuaType::StrTplRef(_) + | LuaType::SelfInfer + | LuaType::Conditional(_) + | LuaType::Mapped(_) + | LuaType::Call(_) + ) + }) } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs new file mode 100644 index 000000000..bbf9f25c4 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs @@ -0,0 +1,247 @@ +use hashbrown::{HashMap, HashSet}; +use std::ops::Deref; + +use crate::{ + GenericParam, LuaAliasCallKind, LuaMappedType, LuaMemberKey, LuaObjectType, LuaTupleStatus, + LuaTupleType, LuaType, TypeOps, VariadicType, +}; + +use super::{ + GenericInstantiateContext, GenericInstantiateFrame, instantiate_special_generic, + instantiate_type_generic_inner, key_type_to_member_key, +}; +use crate::semantic::generic::TypeMapper; + +pub(super) fn instantiate_mapped_type( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + mapped: &LuaMappedType, +) -> LuaType { + let Some(frame) = frame.enter() else { + return instantiate_mapped_residual(context, frame, mapped); + }; + + let Some(constraint) = mapped.param.1.type_constraint.as_ref() else { + return instantiate_mapped_residual(context, frame, mapped); + }; + + let Some(key_domain) = resolve_mapped_key_domain(context, frame, constraint) else { + return instantiate_mapped_residual(context, frame, mapped); + }; + + let empty_object = + || LuaType::Object(LuaObjectType::new_with_fields(HashMap::new(), Vec::new()).into()); + + if key_domain.keys.is_empty() { + return empty_object(); + } + + let key_count = key_domain.keys.len(); + let mut visited = HashSet::with_capacity(key_count); + let mut field_indices: HashMap = HashMap::with_capacity(key_count); + let mut fields: Vec<(LuaMemberKey, LuaType)> = Vec::with_capacity(key_count); + let mut index_access: Vec<(LuaType, LuaType)> = Vec::with_capacity(key_count); + for key_ty in key_domain.keys { + if !visited.insert(key_ty.clone()) { + continue; + } + + let local_mapper = + TypeMapper::prepend(mapped.param.0, key_ty.clone(), Some(context.mapper.clone())); + let local_context = context.with_mapper(local_mapper); + let mut value_ty = instantiate_type_generic_inner(&local_context, frame, &mapped.value); + if mapped.is_optional { + value_ty = TypeOps::Union.apply(context.db, &value_ty, &LuaType::Nil); + } + + if let Some(member_key) = key_type_to_member_key(&key_ty) { + if let Some(index) = field_indices.get(&member_key).copied() { + let (_, existing) = &mut fields[index]; + let merged = LuaType::from_vec(vec![existing.clone(), value_ty]); + *existing = merged; + } else { + field_indices.insert(member_key.clone(), fields.len()); + fields.push((member_key, value_ty)); + } + } else { + index_access.push((key_ty, value_ty)); + } + } + + if fields.is_empty() && index_access.is_empty() { + return empty_object(); + } + + if key_domain.tuple_like + && index_access.is_empty() + && let Some(types) = mapped_tuple_types(&fields) + { + return LuaType::Tuple(LuaTupleType::new(types, LuaTupleStatus::InferResolve).into()); + } + + let field_map: HashMap = fields.into_iter().collect(); + LuaType::Object(LuaObjectType::new_with_fields(field_map, index_access).into()) +} + +struct MappedKeyDomain { + keys: Vec, + tuple_like: bool, +} + +fn resolve_mapped_key_domain( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + constraint: &LuaType, +) -> Option { + if let LuaType::Call(alias_call) = constraint + && alias_call.get_call_kind() == LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 + { + let source = instantiate_type_generic_inner(context, frame, &alias_call.get_operands()[0]); + let keys = instantiate_special_generic::get_keyof_type(context.db, &source)?; + let mut atoms = Vec::new(); + if !collect_mapped_key_atoms(&keys, &mut atoms) { + return None; + } + return Some(MappedKeyDomain { + keys: atoms, + tuple_like: source.is_tuple() || matches!(source, LuaType::Variadic(_)), + }); + } + + let instantiated = instantiate_type_generic_inner(context, frame, constraint); + match &instantiated { + LuaType::Call(alias_call) + if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 => + { + let source = &alias_call.get_operands()[0]; + let keys = instantiate_special_generic::get_keyof_type(context.db, source)?; + let mut atoms = Vec::new(); + if !collect_mapped_key_atoms(&keys, &mut atoms) { + return None; + } + Some(MappedKeyDomain { + keys: atoms, + tuple_like: source.is_tuple() || matches!(source, LuaType::Variadic(_)), + }) + } + _ => { + let mut atoms = Vec::new(); + if !collect_mapped_key_atoms(&instantiated, &mut atoms) { + return None; + } + Some(MappedKeyDomain { + tuple_like: instantiated.is_tuple(), + keys: atoms, + }) + } + } +} + +fn instantiate_mapped_residual( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + mapped: &LuaMappedType, +) -> LuaType { + let param = ( + mapped.param.0, + GenericParam::new( + mapped.param.1.name.clone(), + mapped + .param + .1 + .type_constraint + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, frame, ty)), + mapped + .param + .1 + .default_type + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, frame, ty)), + mapped.param.1.attributes.clone(), + ), + ); + + LuaType::Mapped( + LuaMappedType::new( + param, + instantiate_type_generic_inner(context, frame, &mapped.value), + mapped.is_readonly, + mapped.is_optional, + ) + .into(), + ) +} + +fn mapped_tuple_types(fields: &[(LuaMemberKey, LuaType)]) -> Option> { + let mut indexed = fields + .iter() + .filter_map(|(key, ty)| match key { + LuaMemberKey::Integer(i) => Some((*i, ty.clone())), + _ => None, + }) + .collect::>(); + + if indexed.len() != fields.len() { + return None; + } + + indexed.sort_by_key(|(index, _)| *index); + let starts_at_zero = indexed.first().is_some_and(|(index, _)| *index == 0); + let expected_start = if starts_at_zero { 0 } else { 1 }; + for (offset, (index, _)) in indexed.iter().enumerate() { + if *index != expected_start + offset as i64 { + return None; + } + } + + Some(indexed.into_iter().map(|(_, ty)| ty).collect()) +} + +fn collect_mapped_key_atoms(key_ty: &LuaType, acc: &mut Vec) -> bool { + match key_ty { + LuaType::Union(union) => { + for member in union.into_vec() { + if !collect_mapped_key_atoms(&member, acc) { + return false; + } + } + true + } + LuaType::MultiLineUnion(multi) => { + for (member, _) in multi.get_unions() { + if !collect_mapped_key_atoms(member, acc) { + return false; + } + } + true + } + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => collect_mapped_key_atoms(base, acc), + VariadicType::Multi(types) => { + for member in types { + if !collect_mapped_key_atoms(member, acc) { + return false; + } + } + true + } + }, + LuaType::Tuple(tuple) => { + for member in tuple.get_types() { + if !collect_mapped_key_atoms(member, acc) { + return false; + } + } + true + } + LuaType::Never => true, + LuaType::Unknown | LuaType::Call(_) | LuaType::Mapped(_) => false, + _ => { + acc.push(key_ty.clone()); + true + } + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 57fae183d..916245e38 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -1,6 +1,6 @@ use crate::{ DbIndex, LuaAliasCallKind, LuaAliasCallType, LuaMemberInfo, LuaMemberKey, LuaObjectType, - LuaTupleStatus, LuaTupleType, LuaType, LuaTypeNode, TypeOps, VariadicType, get_member_map, + LuaType, LuaTypeNode, TypeOps, VariadicType, get_member_map, semantic::{ generic::key_type_to_member_key, member::{find_members, infer_raw_member_type}, @@ -10,16 +10,18 @@ use crate::{ use hashbrown::HashMap; use std::{ops::Deref, vec}; -use super::{GenericInstantiateContext, TypeSubstitutor, instantiate_type_generic_with_context}; +use super::{GenericInstantiateContext, GenericInstantiateFrame, instantiate_type_generic_inner}; +use crate::semantic::generic::get_mapped_value; pub(super) fn instantiate_alias_call( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, alias_call: &LuaAliasCallType, ) -> LuaType { let operand_exprs = alias_call.get_operands(); let operands = operand_exprs .iter() - .map(|it| instantiate_type_generic_with_context(context, it)) + .map(|it| instantiate_type_generic_inner(context, frame, it)) .collect::>(); match alias_call.get_call_kind() { @@ -42,16 +44,12 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } - let members = get_keyof_members(context.db, &operands[0]).unwrap_or_default(); - let member_key_types = members - .iter() - .filter_map(|m| match &m.key { - LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)), - LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())), - _ => None, - }) - .collect::>(); - LuaType::Tuple(LuaTupleType::new(member_key_types, LuaTupleStatus::InferResolve).into()) + match get_keyof_type(context.db, &operands[0]) { + Some(key_type) => key_type, + None => { + LuaType::Call(LuaAliasCallType::new(LuaAliasCallKind::KeyOf, operands).into()) + } + } } // 条件类型不在此处理 LuaAliasCallKind::Extends => { @@ -82,7 +80,7 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } - let key = resolve_literal_operand(operand_exprs.get(1), context.substitutor) + let key = resolve_literal_operand(operand_exprs.get(1), context) .unwrap_or_else(|| operands[1].clone()); instantiate_rawget_call(context.db, &operands[0], &key) @@ -92,7 +90,7 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } - let key = resolve_literal_operand(operand_exprs.get(1), context.substitutor) + let key = resolve_literal_operand(operand_exprs.get(1), context) .unwrap_or_else(|| operands[1].clone()); instantiate_index_call(context.db, &operands[0], &key) @@ -101,6 +99,29 @@ pub(super) fn instantiate_alias_call( } } +pub(super) fn get_keyof_type(db: &DbIndex, ty: &LuaType) -> Option { + let members = get_keyof_members(db, ty)?; + let member_key_types = members + .iter() + .filter_map(|m| match &m.key { + LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)), + LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())), + LuaMemberKey::ExprType(typ) => Some(typ.clone()), + _ => None, + }) + .collect::>(); + + if member_key_types.is_empty() { + if members.is_empty() { + return Some(LuaType::Never); + } + + return None; + } + + Some(LuaType::from_vec(member_key_types)) +} + fn instantiate_merge_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { if operands.len() != 2 { return LuaType::Unknown; @@ -132,11 +153,12 @@ fn instantiate_merge_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { fn resolve_literal_operand( operand: Option<&LuaType>, - substitutor: &TypeSubstitutor, + context: &GenericInstantiateContext, ) -> Option { match operand { Some(LuaType::TplRef(tpl_ref)) | Some(LuaType::ConstTplRef(tpl_ref)) => { - substitutor.get_raw_type(tpl_ref.get_tpl_id()).cloned() + get_mapped_value(tpl_ref.get_tpl_id(), &context.mapper) + .and_then(|value| value.raw_type()) } _ => None, } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index ae9d07e77..ec1c1c430 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -1,172 +1,188 @@ mod complete_generic_args; +mod context; +mod infer_call_func_generic; +mod inference_widening; mod instantiate_conditional_generic; -mod instantiate_func_generic; +mod instantiate_mapped_type; mod instantiate_special_generic; -use hashbrown::{HashMap, HashSet}; -use std::{ops::Deref, sync::Arc}; +use hashbrown::HashMap; +use std::ops::Deref; use crate::{ - DbIndex, GenericTpl, GenericTplId, LuaArrayType, LuaMappedType, LuaMemberKey, - LuaOperatorMetaMethod, LuaSignatureId, LuaTupleStatus, LuaTupleType, LuaTypeDeclId, - LuaTypeNode, TypeOps, + DbIndex, GenericTpl, LuaArrayType, LuaMemberKey, LuaOperatorMetaMethod, LuaSignatureId, + LuaTupleType, LuaTypeDeclId, LuaTypeNode, db_index::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaType, LuaUnionType, VariadicType, }, - semantic::infer::InferFailReason, }; -use super::type_substitutor::{ - GenericInstantiateContext, SubstitutorValue, TypeSubstitutor, UninferredTplPolicy, -}; +use super::{TypeMapper, TypeMapperValue, get_mapped_value}; pub use complete_generic_args::{ GenericArgumentCompletion, complete_type_generic_args, complete_type_generic_args_in_type, }; -pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate_func_generic}; +pub use context::TplResolvePolicy; +use context::{GenericInstantiateContext, GenericInstantiateFrame}; +pub use infer_call_func_generic::{build_self_type, infer_call_func_generic, infer_self_type}; +pub(in crate::semantic::generic) use inference_widening::{ + is_primitive_or_literal_type, regularize_tpl_candidate_type, widen_tpl_candidate_type, +}; +use instantiate_mapped_type::instantiate_mapped_type as instantiate_mapped_type_inner; pub use instantiate_special_generic::get_keyof_members; -pub(crate) fn collect_callable_overload_groups( - db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, -) -> Result<(), InferFailReason> { - let mut visiting_aliases = HashSet::new(); - collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) +pub fn instantiate_type_generic(db: &DbIndex, ty: &LuaType, mapper: &TypeMapper) -> LuaType { + instantiate_type_generic_full(db, ty, mapper, None, TplResolvePolicy::Fallback) } -fn collect_callable_overload_groups_inner( +pub fn instantiate_type_generic_full( db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, - visiting_aliases: &mut HashSet, -) -> Result<(), InferFailReason> { - match callable_type { - LuaType::Ref(type_id) | LuaType::Def(type_id) => { - let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { - return Ok(()); - }; - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); - } + ty: &LuaType, + mapper: &TypeMapper, + self_type: Option<&LuaType>, + root_policy: TplResolvePolicy, +) -> LuaType { + let context = GenericInstantiateContext::new(db, mapper, self_type); + let frame = context.root_frame().with_policy(root_policy); + match ty { + LuaType::DocFunction(doc_func) => instantiate_doc_function(&context, frame, doc_func), + _ => instantiate_type_generic_inner(&context, frame, ty), + } +} - let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(type_id); - result?; +pub(super) fn instantiate_type_generic_inner( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + ty: &LuaType, +) -> LuaType { + if is_simple_instantiate_leaf(ty) { + return ty.clone(); + } + + let Some(frame) = frame.enter() else { + return ty.clone(); + }; + + match ty { + LuaType::Array(array_type) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + instantiate_array(context, frame, array_type.get_base()) } - LuaType::Generic(generic) => { - let type_id = generic.get_base_type_id(); - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); + LuaType::Tuple(tuple) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); } - let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); - let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { - visiting_aliases.remove(&type_id); - return Ok(()); - }; - - let result = if let Some(origin_type) = - type_decl.get_alias_origin(db, Some(&substitutor)) - { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(&type_id); - result?; + instantiate_tuple(context, frame, tuple) + } + LuaType::DocFunction(doc_func) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + instantiate_doc_function( + context, + frame.with_policy(TplResolvePolicy::PreserveTplRef), + doc_func, + ) + } + LuaType::Object(object) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + instantiate_object(context, frame, object) } LuaType::Union(union) => { - for member in union.into_vec() { - collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; + if !requires_instantiation_walk(ty) { + return ty.clone(); } + instantiate_union(context, frame, union) } LuaType::Intersection(intersection) => { - for member in intersection.get_types() { - collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; + if !requires_instantiation_walk(ty) { + return ty.clone(); } + instantiate_intersection(context, frame, intersection) } - LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), - LuaType::Signature(sig_id) => { - let Some(signature) = db.get_signature_index().get(sig_id) else { - return Ok(()); - }; - let mut overloads = signature.overloads.to_vec(); - overloads.push(signature.to_doc_func_type()); - groups.push(overloads); + LuaType::Generic(generic) => instantiate_generic(context, frame, generic), + LuaType::TableGeneric(table_params) => { + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + instantiate_table_generic(context, frame, table_params) } - _ => {} - } - - Ok(()) -} - -pub fn instantiate_type_generic( - db: &DbIndex, - ty: &LuaType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_type_generic_with_context(&context, ty) -} - -pub(super) fn instantiate_type_generic_with_context( - context: &GenericInstantiateContext, - ty: &LuaType, -) -> LuaType { - match ty { - LuaType::Array(array_type) => instantiate_array(context, array_type.get_base()), - LuaType::Tuple(tuple) => instantiate_tuple(context, tuple), - LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context( - &context.with_policy(UninferredTplPolicy::PreserveTplRef), - doc_func, - ), - LuaType::Object(object) => instantiate_object(context, object), - LuaType::Union(union) => instantiate_union(context, union), - LuaType::Intersection(intersection) => instantiate_intersection(context, intersection), - LuaType::Generic(generic) => instantiate_generic_with_context(context, generic), - LuaType::TableGeneric(table_params) => instantiate_table_generic(context, table_params), - LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context), - LuaType::ConstTplRef(tpl) => instantiate_const_tpl_ref(tpl, context), - LuaType::Signature(sig_id) => instantiate_signature(context, sig_id), + LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context, frame), + LuaType::ConstTplRef(tpl) => instantiate_const_tpl_ref(tpl, context, frame), + LuaType::Signature(sig_id) => instantiate_signature(context, frame, sig_id), LuaType::Call(alias_call) => { - instantiate_special_generic::instantiate_alias_call(context, alias_call) + instantiate_special_generic::instantiate_alias_call(context, frame, alias_call) } - LuaType::Variadic(variadic) => instantiate_variadic_type(context, variadic), + LuaType::Variadic(variadic) => instantiate_variadic_type(context, frame, variadic), LuaType::SelfInfer => { - if let Some(typ) = context.substitutor.get_self_type() { + if let Some(typ) = context.self_type() { typ.clone() } else { LuaType::SelfInfer } } LuaType::TypeGuard(guard) => { - let inner = instantiate_type_generic_with_context(context, guard.deref()); + if !requires_instantiation_walk(ty) { + return ty.clone(); + } + let inner = instantiate_type_generic_inner(context, frame, guard.deref()); LuaType::TypeGuard(inner.into()) } LuaType::Conditional(conditional) => { - instantiate_conditional_generic::instantiate_conditional(context, conditional) + instantiate_conditional_generic::instantiate_conditional(context, frame, conditional) } - LuaType::Mapped(mapped) => instantiate_mapped_type(context, mapped.deref()), + LuaType::Mapped(mapped) => instantiate_mapped_type_inner(context, frame, mapped.deref()), _ => ty.clone(), } } -fn instantiate_types<'a, I>(context: &GenericInstantiateContext, types: I) -> Vec +fn requires_instantiation_walk(ty: &LuaType) -> bool { + match ty { + LuaType::TplRef(_) + | LuaType::StrTplRef(_) + | LuaType::ConstTplRef(_) + | LuaType::SelfInfer + | LuaType::Generic(_) + | LuaType::Signature(_) + | LuaType::Call(_) + | LuaType::Conditional(_) + | LuaType::Mapped(_) => true, + LuaType::Array(array_type) => requires_instantiation_walk(array_type.get_base()), + LuaType::Tuple(tuple) => tuple.any_type(requires_instantiation_walk), + LuaType::DocFunction(doc_func) => doc_func.any_type(requires_instantiation_walk), + LuaType::Object(object) => object.any_type(requires_instantiation_walk), + LuaType::Union(union) => union.any_type(requires_instantiation_walk), + LuaType::Intersection(intersection) => intersection.any_type(requires_instantiation_walk), + LuaType::TableGeneric(table_params) => table_params.iter().any(requires_instantiation_walk), + LuaType::Variadic(variadic) => variadic.any_type(requires_instantiation_walk), + LuaType::TypeGuard(guard) => requires_instantiation_walk(guard.deref()), + LuaType::MultiLineUnion(inner) => inner.any_type(requires_instantiation_walk), + LuaType::DocAttribute(attr) => attr.any_type(requires_instantiation_walk), + _ => false, + } +} + +fn instantiate_types<'a, I>( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + types: I, +) -> Vec where I: IntoIterator, { types .into_iter() - .map(|ty| instantiate_type_generic_with_context(context, ty)) + .map(|ty| instantiate_type_generic_inner(context, frame, ty)) .collect() } fn instantiate_type_pairs<'a, I>( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, pairs: I, ) -> Vec<(LuaType, LuaType)> where @@ -176,41 +192,47 @@ where .into_iter() .map(|(key, value)| { ( - instantiate_type_generic_with_context(context, key), - instantiate_type_generic_with_context(context, value), + instantiate_type_generic_inner(context, frame, key), + instantiate_type_generic_inner(context, frame, value), ) }) .collect() } -fn instantiate_array(context: &GenericInstantiateContext, base: &LuaType) -> LuaType { - let base = instantiate_type_generic_with_context(context, base); +fn instantiate_array( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + base: &LuaType, +) -> LuaType { + let base = instantiate_type_generic_inner(context, frame, base); LuaType::Array(LuaArrayType::from_base_type(base).into()) } -fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) -> LuaType { +fn instantiate_tuple( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + tuple: &LuaTupleType, +) -> LuaType { let mut new_types = Vec::new(); for t in tuple.get_types() { if let LuaType::Variadic(inner) = t { match inner.deref() { VariadicType::Base(base) => { if let LuaType::TplRef(tpl) = base { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => new_types - .push(instantiate_uninferred_tpl_fallback(tpl, context)), - SubstitutorValue::MultiTypes(types) => { - for typ in types { - new_types.push(typ.clone()); - } - } - SubstitutorValue::Params(params) => { + TypeMapperValue::None => new_types + .push(instantiate_uninferred_tpl_fallback(tpl, context, frame)), + TypeMapperValue::Params(params) => { for (_, ty) in params { - new_types.push(ty.clone().unwrap_or(LuaType::Unknown)); + new_types.push(ty.unwrap_or(LuaType::Unknown)); } } - SubstitutorValue::Type(ty) => new_types.push(ty.default().clone()), - SubstitutorValue::MultiBase(base) => new_types.push(base.clone()), + TypeMapperValue::MultiTypes(values) => { + new_types.extend(values); + } + TypeMapperValue::Type(value) => new_types.push(value), + TypeMapperValue::MultiBase(base) => new_types.push(base), } } else { new_types.push(LuaType::Variadic(inner.clone())); @@ -223,25 +245,21 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) break; } - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, frame, t); new_types.push(t); } LuaType::Tuple(LuaTupleType::new(new_types, tuple.status).into()) } -pub fn instantiate_doc_function( - db: &DbIndex, - doc_func: &LuaFunctionType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_doc_function_with_context(&context, doc_func) -} - -fn instantiate_doc_function_with_context( +fn instantiate_doc_function( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, doc_func: &LuaFunctionType, ) -> LuaType { + if !doc_func.any_type(requires_instantiation_walk) { + return LuaType::DocFunction(doc_func.clone().into()); + } + let tpl_func_params = doc_func.get_params(); let tpl_ret = doc_func.get_ret(); let async_state = doc_func.get_async_state(); @@ -258,19 +276,19 @@ fn instantiate_doc_function_with_context( match origin_param_type { LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Base(base) => match base { - LuaType::TplRef(tpl) => { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => { - let ty = instantiate_uninferred_tpl_fallback(tpl, context); + TypeMapperValue::None => { + let ty = + instantiate_uninferred_tpl_fallback(tpl, context, frame); new_params.push((origin_param.0.clone(), Some(ty))); } - SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + TypeMapperValue::Type(resolved_type) => { // 如果参数是 `...: T...` if origin_param.0 == "..." { // 类型是 tuple, 那么我们将展开 tuple - if let LuaType::Tuple(tuple) = resolved_type { + if let LuaType::Tuple(tuple) = &resolved_type { let base_index = new_params.len(); for (i, typ) in tuple.get_types().iter().enumerate() { let param_name = format!("var{}", base_index + i); @@ -288,19 +306,19 @@ fn instantiate_doc_function_with_context( new_params.push(( origin_param.0.clone(), Some(LuaType::Variadic( - VariadicType::Base(resolved_type.clone()).into(), + VariadicType::Base(resolved_type).into(), )), )); } - SubstitutorValue::Params(params) => { - for param in params.iter() { - new_params.push(param.clone()); + TypeMapperValue::Params(params) => { + for param in params { + new_params.push(param); } } - SubstitutorValue::MultiTypes(types) => { - for (i, typ) in types.iter().enumerate() { + TypeMapperValue::MultiTypes(values) => { + for (i, value) in values.into_iter().enumerate() { let param_name = format!("var{}", i); - new_params.push((param_name, Some(typ.clone()))); + new_params.push((param_name, Some(value))); } } _ => { @@ -318,7 +336,7 @@ fn instantiate_doc_function_with_context( } } LuaType::Generic(generic) => { - let new_type = instantiate_generic_with_context(context, generic); + let new_type = instantiate_generic(context, frame, generic); // 如果是 rest 参数且实例化后的类型是 tuple, 那么我们将展开 tuple if let LuaType::Tuple(tuple_type) = &new_type { let base_index = new_params.len(); @@ -336,13 +354,13 @@ fn instantiate_doc_function_with_context( VariadicType::Multi(_) => (), }, _ => { - let new_type = instantiate_type_generic_with_context(context, origin_param_type); + let new_type = instantiate_type_generic_inner(context, frame, origin_param_type); new_params.push((origin_param.0.clone(), Some(new_type))); } } } - let mut inst_ret_type = instantiate_type_generic_with_context(context, tpl_ret); + let mut inst_ret_type = instantiate_type_generic_inner(context, frame, tpl_ret); // 对于可变返回值, 如果实例化是 tuple, 那么我们将展开 tuple if let LuaType::Variadic(_) = &&tpl_ret && let LuaType::Tuple(tuple) = &inst_ret_type @@ -380,52 +398,69 @@ fn instantiate_doc_function_with_context( ) } -fn instantiate_object(context: &GenericInstantiateContext, object: &LuaObjectType) -> LuaType { +fn instantiate_object( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + object: &LuaObjectType, +) -> LuaType { + if !object.any_type(requires_instantiation_walk) { + return LuaType::Object(object.clone().into()); + } + let new_fields = object .get_fields() .iter() .map(|(key, field)| { ( key.clone(), - instantiate_type_generic_with_context(context, field), + instantiate_type_generic_inner(context, frame, field), ) }) .collect::>(); - let new_index_access = instantiate_type_pairs(context, object.get_index_access().iter()); + let new_index_access = instantiate_type_pairs(context, frame, object.get_index_access().iter()); LuaType::Object(LuaObjectType::new_with_fields(new_fields, new_index_access).into()) } -fn instantiate_union(context: &GenericInstantiateContext, union: &LuaUnionType) -> LuaType { - LuaType::from_vec(instantiate_types(context, union.into_vec().iter())) +fn instantiate_union( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + union: &LuaUnionType, +) -> LuaType { + if !union.any_type(requires_instantiation_walk) { + return LuaType::Union(union.clone().into()); + } + + LuaType::from_vec(instantiate_types(context, frame, union.into_vec().iter())) } fn instantiate_intersection( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, intersection: &LuaIntersectionType, ) -> LuaType { + if !intersection.any_type(requires_instantiation_walk) { + return LuaType::Intersection(intersection.clone().into()); + } + LuaType::Intersection( - LuaIntersectionType::new(instantiate_types(context, intersection.get_types().iter())) - .into(), + LuaIntersectionType::new(instantiate_types( + context, + frame, + intersection.get_types().iter(), + )) + .into(), ) } -pub fn instantiate_generic( - db: &DbIndex, - generic: &LuaGenericType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_generic_with_context(&context, generic) -} - -fn instantiate_generic_with_context( +fn instantiate_generic( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, generic: &LuaGenericType, ) -> LuaType { let generic_params = generic.get_params(); - let new_params = instantiate_types(context, generic_params.iter()); + let new_params = instantiate_types(context, frame, generic_params.iter()); let base = generic.get_base_type(); let type_decl_id = if let LuaType::Ref(id) = base { @@ -434,13 +469,17 @@ fn instantiate_generic_with_context( return LuaType::Unknown; }; - if !context.substitutor.check_recursion(&type_decl_id) - && let Some(type_decl) = context.db.get_type_index().get_type_decl(&type_decl_id) + if let Some(type_decl) = context.db.get_type_index().get_type_decl(&type_decl_id) && type_decl.is_alias() { - let new_substitutor = TypeSubstitutor::from_alias(new_params.clone(), type_decl_id.clone()); - if let Some(origin) = type_decl.get_alias_origin(context.db, Some(&new_substitutor)) { - return origin; + let Some(alias_stack) = context.enter_alias_stack(&type_decl_id) else { + return LuaType::Generic(LuaGenericType::new(type_decl_id, new_params).into()); + }; + let alias_mapper = TypeMapper::from_alias(context.db, new_params.clone(), &type_decl_id); + let alias_mapper = TypeMapper::merge(Some(alias_mapper), context.mapper.clone()); + let alias_context = context.with_alias_stack(alias_stack, alias_mapper); + if let Some(origin) = type_decl.get_alias_ref() { + return instantiate_type_generic_inner(&alias_context, frame, origin); } } @@ -449,76 +488,93 @@ fn instantiate_generic_with_context( fn instantiate_table_generic( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, table_params: &[LuaType], ) -> LuaType { - LuaType::TableGeneric(instantiate_types(context, table_params.iter()).into()) + if !table_params.iter().any(requires_instantiation_walk) { + return LuaType::TableGeneric(table_params.to_vec().into()); + } + + LuaType::TableGeneric(instantiate_types(context, frame, table_params.iter()).into()) } fn instantiate_uninferred_tpl_fallback( tpl: &GenericTpl, context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, ) -> LuaType { // 一些情况下需要保留 TplRef, 例如高阶函数调用 - if context.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { + if frame.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { return LuaType::TplRef(tpl.clone().into()); } // 显式默认值优先, 然后是 extends 约束, 最后才是 unknown. if let Some(default_type) = tpl.get_default_type() { - return instantiate_type_generic_with_context(context, default_type); + return instantiate_type_generic_inner(context, frame, default_type); } if let Some(constraint) = tpl.get_constraint() { - return instantiate_type_generic_with_context(context, constraint); + return instantiate_type_generic_inner(context, frame, constraint); } LuaType::Unknown } -fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { +fn instantiate_tpl_ref( + tpl: &GenericTpl, + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, +) -> LuaType { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => { - return instantiate_uninferred_tpl_fallback(tpl, context); + TypeMapperValue::None => { + return instantiate_uninferred_tpl_fallback(tpl, context, frame); } - SubstitutorValue::Type(ty) => return ty.default().clone(), - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + TypeMapperValue::Type(value) => { + return value; } - SubstitutorValue::Params(params) => { + TypeMapperValue::MultiTypes(values) => { + return LuaType::Variadic(VariadicType::Multi(values).into()); + } + TypeMapperValue::Params(params) => { return params .first() - .unwrap_or(&(String::new(), None)) - .1 - .clone() + .and_then(|(_, ty)| ty.clone()) .unwrap_or(LuaType::Unknown); } - SubstitutorValue::MultiBase(base) => return base.clone(), + TypeMapperValue::MultiBase(base) => return base, } } LuaType::TplRef(tpl.clone().into()) } -fn instantiate_const_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { +fn instantiate_const_tpl_ref( + tpl: &GenericTpl, + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, +) -> LuaType { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => { - return instantiate_uninferred_tpl_fallback(tpl, context); + TypeMapperValue::None => { + if frame.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { + return LuaType::ConstTplRef(tpl.clone().into()); + } + return instantiate_uninferred_tpl_fallback(tpl, context, frame); + } + TypeMapperValue::Type(value) => { + return value; } - SubstitutorValue::Type(ty) => return ty.raw().clone(), - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + TypeMapperValue::MultiTypes(values) => { + return LuaType::Variadic(VariadicType::Multi(values).into()); } - SubstitutorValue::Params(params) => { + TypeMapperValue::Params(params) => { return params .first() - .unwrap_or(&(String::new(), None)) - .1 - .clone() + .and_then(|(_, ty)| ty.clone()) .unwrap_or(LuaType::Unknown); } - SubstitutorValue::MultiBase(base) => return base.clone(), + TypeMapperValue::MultiBase(base) => return base, } } @@ -527,20 +583,22 @@ fn instantiate_const_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateConte fn instantiate_signature( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, signature_id: &LuaSignatureId, ) -> LuaType { if let Some(signature) = context.db.get_signature_index().get(signature_id) { let origin_type = { let fake_doc_function = signature.to_doc_func_type(); - instantiate_doc_function_with_context(context, &fake_doc_function) + instantiate_doc_function(context, frame, &fake_doc_function) }; if signature.overloads.is_empty() { return origin_type; } else { let mut result = Vec::new(); for overload in signature.overloads.iter() { - result.push(instantiate_doc_function_with_context( + result.push(instantiate_doc_function( context, + frame, &(*overload).clone(), )); } @@ -554,45 +612,47 @@ fn instantiate_signature( fn instantiate_variadic_type( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, variadic: &VariadicType, ) -> LuaType { + if !variadic.any_type(requires_instantiation_walk) { + return LuaType::Variadic(variadic.clone().into()); + } + match variadic { VariadicType::Base(base) => match base { LuaType::TplRef(tpl) => { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { + if let Some(value) = get_mapped_value(tpl.get_tpl_id(), &context.mapper) { match value { - SubstitutorValue::None => { - let fallback = instantiate_uninferred_tpl_fallback(tpl, context); + TypeMapperValue::None => { + let fallback = instantiate_uninferred_tpl_fallback(tpl, context, frame); return match fallback { LuaType::Variadic(_) | LuaType::Never => fallback, LuaType::Nil | LuaType::Any | LuaType::Unknown => fallback, _ => LuaType::Variadic(VariadicType::Base(fallback).into()), }; } - SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + TypeMapperValue::Type(resolved_type) => { if matches!( resolved_type, LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never ) { - return resolved_type.clone(); + return resolved_type; } - return LuaType::Variadic( - VariadicType::Base(resolved_type.clone()).into(), - ); + return LuaType::Variadic(VariadicType::Base(resolved_type).into()); } - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + TypeMapperValue::MultiTypes(values) => { + return LuaType::Variadic(VariadicType::Multi(values).into()); } - SubstitutorValue::Params(params) => { + TypeMapperValue::Params(params) => { let types = params - .iter() - .filter_map(|(_, ty)| ty.clone()) + .into_iter() + .filter_map(|(_, ty)| ty) .collect::>(); return LuaType::Variadic(VariadicType::Multi(types).into()); } - SubstitutorValue::MultiBase(base) => { - return LuaType::Variadic(VariadicType::Base(base.clone()).into()); + TypeMapperValue::MultiBase(base) => { + return LuaType::Variadic(VariadicType::Base(base).into()); } } } else { @@ -600,7 +660,7 @@ fn instantiate_variadic_type( } } LuaType::Generic(generic) => { - return instantiate_generic_with_context(context, generic); + return instantiate_generic(context, frame, generic); } _ => {} }, @@ -608,7 +668,7 @@ fn instantiate_variadic_type( if types.iter().any(LuaTypeNode::contains_tpl_node) { let mut new_types = Vec::new(); for t in types { - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, frame, t); match t { LuaType::Never => {} LuaType::Variadic(variadic) => match variadic.deref() { @@ -630,92 +690,6 @@ fn instantiate_variadic_type( LuaType::Variadic(variadic.clone().into()) } -fn instantiate_mapped_type(context: &GenericInstantiateContext, mapped: &LuaMappedType) -> LuaType { - let constraint = mapped - .param - .1 - .type_constraint - .as_ref() - .map(|ty| instantiate_type_generic_with_context(context, ty)); - - if let Some(constraint) = constraint { - let mut key_types = Vec::new(); - collect_mapped_key_atoms(&constraint, &mut key_types); - - let mut visited = HashSet::new(); - let mut fields: Vec<(LuaMemberKey, LuaType)> = Vec::new(); - let mut index_access: Vec<(LuaType, LuaType)> = Vec::new(); - - for key_ty in key_types { - if !visited.insert(key_ty.clone()) { - continue; - } - - let value_ty = instantiate_mapped_value(context, mapped, mapped.param.0, &key_ty); - - if let Some(member_key) = key_type_to_member_key(&key_ty) { - if let Some((_, existing)) = fields.iter_mut().find(|(key, _)| key == &member_key) { - let merged = LuaType::from_vec(vec![existing.clone(), value_ty]); - *existing = merged; - } else { - fields.push((member_key, value_ty)); - } - } else { - index_access.push((key_ty, value_ty)); - } - } - - if !fields.is_empty() || !index_access.is_empty() { - // key 从 0 开始递增才被视为元组 - if constraint.is_tuple() { - let mut index = 0; - let mut is_tuple = true; - for (key, _) in &fields { - if let LuaMemberKey::Integer(i) = key { - if *i != index { - is_tuple = false; - break; - } - index += 1; - } else { - is_tuple = false; - break; - } - } - if is_tuple { - let types = fields.into_iter().map(|(_, ty)| ty).collect(); - return LuaType::Tuple( - LuaTupleType::new(types, LuaTupleStatus::InferResolve).into(), - ); - } - } - let field_map: HashMap = fields.into_iter().collect(); - return LuaType::Object(LuaObjectType::new_with_fields(field_map, index_access).into()); - } - } - - instantiate_type_generic_with_context(context, &mapped.value) -} - -fn instantiate_mapped_value( - context: &GenericInstantiateContext, - mapped: &LuaMappedType, - tpl_id: GenericTplId, - replacement: &LuaType, -) -> LuaType { - let mut local_substitutor = context.substitutor.clone(); - local_substitutor.insert_type(tpl_id, replacement.clone(), true); - let local_context = context.with_substitutor(&local_substitutor); - let mut result = instantiate_type_generic_with_context(&local_context, &mapped.value); - // 根据 readonly 和 optional 属性进行处理 - if mapped.is_optional { - result = TypeOps::Union.apply(context.db, &result, &LuaType::Nil); - } - // TODO: 处理 readonly, 但目前 readonly 的实现存在问题, 这里我们先跳过 - - result -} - pub(super) fn key_type_to_member_key(key_ty: &LuaType) -> Option { match key_ty { LuaType::DocStringConst(s) => Some(LuaMemberKey::Name(s.deref().clone())), @@ -726,36 +700,6 @@ pub(super) fn key_type_to_member_key(key_ty: &LuaType) -> Option { } } -fn collect_mapped_key_atoms(key_ty: &LuaType, acc: &mut Vec) { - match key_ty { - LuaType::Union(union) => { - for member in union.into_vec() { - collect_mapped_key_atoms(&member, acc); - } - } - LuaType::MultiLineUnion(multi) => { - for (member, _) in multi.get_unions() { - collect_mapped_key_atoms(member, acc); - } - } - LuaType::Variadic(variadic) => match variadic.deref() { - VariadicType::Base(base) => collect_mapped_key_atoms(base, acc), - VariadicType::Multi(types) => { - for member in types { - collect_mapped_key_atoms(member, acc); - } - } - }, - LuaType::Tuple(tuple) => { - for member in tuple.get_types() { - collect_mapped_key_atoms(member, acc); - } - } - LuaType::Unknown | LuaType::Never => {} - _ => acc.push(key_ty.clone()), - } -} - pub(super) fn get_default_constructor(db: &DbIndex, decl_id: &LuaTypeDeclId) -> Option { let ids = db .get_operator_index() @@ -765,3 +709,36 @@ pub(super) fn get_default_constructor(db: &DbIndex, decl_id: &LuaTypeDeclId) -> let operator = db.get_operator_index().get_operator(id)?; Some(operator.get_operator_func(db)) } + +fn is_simple_instantiate_leaf(ty: &LuaType) -> bool { + matches!( + ty, + LuaType::Unknown + | LuaType::Any + | LuaType::Nil + | LuaType::Table + | LuaType::Userdata + | LuaType::Function + | LuaType::Thread + | LuaType::Boolean + | LuaType::String + | LuaType::Integer + | LuaType::Number + | LuaType::Io + | LuaType::Global + | LuaType::Never + | LuaType::BooleanConst(_) + | LuaType::StringConst(_) + | LuaType::IntegerConst(_) + | LuaType::FloatConst(_) + | LuaType::TableConst(_) + | LuaType::Ref(_) + | LuaType::Def(_) + | LuaType::DocStringConst(_) + | LuaType::DocIntegerConst(_) + | LuaType::DocBooleanConst(_) + | LuaType::Namespace(_) + | LuaType::Language(_) + | LuaType::ModuleRef(_) + ) +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index 90e34baa3..a617ae2f0 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -1,9 +1,10 @@ mod call_constraint; +mod inference; mod instantiate_type; mod test; -mod tpl_context; -mod tpl_pattern; -mod type_substitutor; +mod type_mapper; + +use std::sync::Arc; pub use call_constraint::{ CallConstraintArg, CallConstraintContext, build_call_constraint_context, @@ -11,27 +12,222 @@ pub use call_constraint::{ }; use emmylua_parser::LuaAstNode; use emmylua_parser::LuaExpr; -pub(crate) use instantiate_type::collect_callable_overload_groups; +use hashbrown::HashSet; +pub(in crate::semantic::generic) use inference::{ + InferenceContext, InferencePriority, InferenceVariance, infer_type_list, infer_types_from_expr, + multi_param_infer_multi_return, return_type_infer_types, variadic_infer_types, +}; pub use instantiate_type::*; use rowan::NodeOrToken; -pub use tpl_context::TplContext; -pub use tpl_pattern::tpl_pattern_match_args; -pub use tpl_pattern::tpl_pattern_match_args_skip_unknown; -pub use type_substitutor::TypeSubstitutor; +pub(in crate::semantic::generic) use type_mapper::get_mapped_value; +pub use type_mapper::{TypeMapper, TypeMapperValue}; use crate::DbIndex; +use crate::GenericTpl; use crate::GenericTplId; +use crate::InferFailReason; use crate::LuaDeclExtra; +use crate::LuaFunctionType; use crate::LuaInferCache; use crate::LuaMemberOwner; use crate::LuaSemanticDeclId; use crate::LuaType; +use crate::LuaTypeNode; use crate::SemanticDeclLevel; use crate::TypeOps; use crate::infer_node_semantic_decl; use crate::semantic::semantic_info::infer_token_semantic_decl; pub use instantiate_type::get_keyof_members; +pub fn instantiate_doc_function_by_arg_types( + db: &DbIndex, + cache: &mut LuaInferCache, + doc_function: &Arc, + call_arg_types: &[LuaType], +) -> Result, InferFailReason> { + let generic_tpl_ids = collect_doc_function_tpl_ids(doc_function); + if generic_tpl_ids.is_empty() { + return Ok(doc_function.clone()); + } + + let param_types = doc_function + .get_params() + .iter() + .map(|(_, typ)| typ.clone().unwrap_or(LuaType::Unknown)) + .collect::>(); + let mut context = InferenceContext::new(db, cache, None); + context.prepare_inference_slots(generic_tpl_ids); + infer_type_list( + &mut context, + ¶m_types, + call_arg_types, + &LuaType::Unknown, + InferenceVariance::Covariant, + InferencePriority::Normal, + )?; + + let generic_tpls = collect_doc_function_generic_tpls(doc_function); + let mapper = context.fixing_mapper(generic_tpls.iter(), doc_function.get_ret()); + + let doc_function_ty = LuaType::DocFunction(doc_function.clone()); + Ok( + match instantiate_type_generic(db, &doc_function_ty, &mapper) { + LuaType::DocFunction(func) => func, + _ => doc_function.clone(), + }, + ) +} + +fn collect_doc_function_tpl_ids(doc_function: &LuaFunctionType) -> HashSet { + let mut generic_tpl_ids = HashSet::new(); + doc_function.visit_nested_types(&mut |ty| match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + collect_function_tpl_with_fallback_deps(&generic_tpl, &mut generic_tpl_ids); + } + LuaType::StrTplRef(str_tpl) => { + let tpl_id = str_tpl.get_tpl_id(); + if !tpl_id.is_func() { + return; + } + + generic_tpl_ids.insert(tpl_id); + let Some(constraint) = str_tpl.get_constraint() else { + return; + }; + + let mut constraint_deps = HashSet::new(); + if collect_function_tpl_deps_from_fallback_type( + constraint, + &mut constraint_deps, + &mut HashSet::new(), + ) { + generic_tpl_ids.extend(constraint_deps); + } + } + _ => {} + }); + + generic_tpl_ids +} + +fn collect_doc_function_generic_tpls(doc_function: &LuaFunctionType) -> Vec> { + let mut generic_tpls = Vec::new(); + doc_function.visit_nested_types(&mut |ty| match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + if generic_tpl.get_tpl_id().is_func() + && !generic_tpls.iter().any(|existing: &Arc| { + existing.get_tpl_id() == generic_tpl.get_tpl_id() + }) + { + generic_tpls.push(generic_tpl.clone()); + } + } + _ => {} + }); + + generic_tpls +} + +fn collect_function_tpl_with_fallback_deps( + generic_tpl: &GenericTpl, + generic_tpl_ids: &mut HashSet, +) { + let tpl_id = generic_tpl.get_tpl_id(); + if !tpl_id.is_func() { + return; + } + + generic_tpl_ids.insert(tpl_id); + let Some(fallback_type) = generic_tpl + .get_default_type() + .or(generic_tpl.get_constraint()) + else { + return; + }; + + let mut fallback_deps = HashSet::new(); + let mut visiting_fallbacks = HashSet::new(); + visiting_fallbacks.insert(tpl_id); + if collect_function_tpl_deps_from_fallback_type( + fallback_type, + &mut fallback_deps, + &mut visiting_fallbacks, + ) { + generic_tpl_ids.extend(fallback_deps); + } +} + +fn collect_function_tpl_deps_from_fallback_type( + ty: &LuaType, + generic_tpl_ids: &mut HashSet, + visiting_fallbacks: &mut HashSet, +) -> bool { + let mut no_fallback_cycle = + collect_function_tpl_dep_from_fallback_type(ty, generic_tpl_ids, visiting_fallbacks); + ty.visit_nested_types(&mut |ty| { + no_fallback_cycle &= + collect_function_tpl_dep_from_fallback_type(ty, generic_tpl_ids, visiting_fallbacks); + }); + no_fallback_cycle +} + +fn collect_function_tpl_dep_from_fallback_type( + ty: &LuaType, + generic_tpl_ids: &mut HashSet, + visiting_fallbacks: &mut HashSet, +) -> bool { + match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + let tpl_id = generic_tpl.get_tpl_id(); + if !tpl_id.is_func() { + return true; + } + + if !visiting_fallbacks.insert(tpl_id) { + return false; + } + + generic_tpl_ids.insert(tpl_id); + let no_fallback_cycle = match generic_tpl + .get_default_type() + .or(generic_tpl.get_constraint()) + { + Some(fallback_type) => collect_function_tpl_deps_from_fallback_type( + fallback_type, + generic_tpl_ids, + visiting_fallbacks, + ), + None => true, + }; + visiting_fallbacks.remove(&tpl_id); + no_fallback_cycle + } + LuaType::StrTplRef(str_tpl) => { + let tpl_id = str_tpl.get_tpl_id(); + if !tpl_id.is_func() { + return true; + } + + if !visiting_fallbacks.insert(tpl_id) { + return false; + } + + generic_tpl_ids.insert(tpl_id); + let no_fallback_cycle = match str_tpl.get_constraint() { + Some(constraint) => collect_function_tpl_deps_from_fallback_type( + constraint, + generic_tpl_ids, + visiting_fallbacks, + ), + None => true, + }; + visiting_fallbacks.remove(&tpl_id); + no_fallback_cycle + } + _ => true, + } +} + pub fn get_tpl_ref_extend_type( db: &DbIndex, cache: &mut LuaInferCache, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/test.rs b/crates/emmylua_code_analysis/src/semantic/generic/test.rs index 21dee2f3f..967ab7103 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/test.rs @@ -1,6 +1,28 @@ #[cfg(test)] mod test { - use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; + use hashbrown::HashMap; + use std::sync::Arc; + + use super::super::instantiate_type::{ + TplResolvePolicy, instantiate_type_generic_full, regularize_tpl_candidate_type, + widen_tpl_candidate_type, + }; + use crate::{ + AsyncState, DbIndex, DiagnosticCode, GenericTpl, GenericTplId, LuaArrayType, + LuaFunctionType, LuaIntersectionType, LuaMemberKey, LuaObjectType, LuaTupleStatus, + LuaTupleType, LuaType, LuaUnionType, TypeMapper, TypeMapperValue, VariadicType, + VirtualWorkspace, + }; + use smol_str::SmolStr; + + fn func_tpl(idx: u32, name: &str) -> Arc { + Arc::new(GenericTpl::new( + GenericTplId::Func(idx), + SmolStr::new(name).into(), + None, + None, + )) + } #[test] fn test_variadic_func() { @@ -298,4 +320,538 @@ result = { "# )); } + + #[test] + fn test_inference_mapper_fallback_and_explicit_precedence() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T = string + ---@return T + local function defaulted() + end + + ---@generic T: integer + ---@return T + local function constrained() + end + + ---@generic T + ---@param value T + ---@return T + local function explicit(value) + end + + default_result = defaulted() + constraint_result = constrained() + explicit_result = explicit--[[@]](1) + + "#, + ); + + assert_eq!(ws.expr_ty("default_result"), ws.ty("string")); + assert_eq!(ws.expr_ty("constraint_result"), ws.ty("integer")); + assert_eq!(ws.expr_ty("explicit_result"), ws.ty("string")); + } + + #[test] + fn test_mapper_reducer_reuses_alias_mapped_and_function_shapes() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias MapperBox { value: T, list: T[] } + ---@alias Copy { [K in keyof T]: T[K]; } + + ---@generic T + ---@param value T + ---@return MapperBox + local function box(value) + end + + ---@generic T + ---@param value T + ---@return Copy + local function copy(value) + end + + ---@generic T + ---@param value T + ---@return fun(next: T): T + local function make_id(value) + end + + box_result = box("name") + box_value = box_result.value + box_list_item = box_result.list[1] + + copied = copy({ name = "a", count = 1 }) + copied_name = copied.name + copied_count = copied.count + + made = make_id(1) + made_ret = made(2) + "#, + ); + + assert_eq!(ws.expr_ty("box_value"), ws.ty("string")); + assert_eq!(ws.expr_ty("box_list_item"), ws.ty("string?")); + assert_eq!(ws.expr_ty("copied_name"), ws.ty("string")); + assert_eq!(ws.expr_ty("copied_count"), ws.ty("integer")); + assert_eq!(ws.expr_ty("made_ret"), ws.ty("integer")); + } + + #[test] + fn test_structural_instantiate_fast_path_preserves_plain_shapes() { + let db = DbIndex::new(); + let empty_mapper = TypeMapper::empty(); + + let plain_array = LuaType::Array(LuaArrayType::from_base_type(LuaType::Number).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_array, + &empty_mapper + ), + plain_array + ); + + let plain_tuple = LuaType::Tuple( + LuaTupleType::new( + vec![LuaType::Number, LuaType::String], + LuaTupleStatus::DocResolve, + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_tuple, + &empty_mapper + ), + plain_tuple + ); + + let plain_object = LuaType::Object( + LuaObjectType::new_with_fields( + HashMap::from([ + (LuaMemberKey::Name(SmolStr::new("name")), LuaType::String), + (LuaMemberKey::Name(SmolStr::new("count")), LuaType::Number), + ]), + Vec::new(), + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_object, + &empty_mapper + ), + plain_object + ); + + let plain_union = + LuaType::Union(LuaUnionType::from_vec(vec![LuaType::Number, LuaType::String]).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_union, + &empty_mapper + ), + plain_union + ); + + let plain_intersection = LuaType::Intersection( + LuaIntersectionType::new(vec![LuaType::Number, LuaType::String]).into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_intersection, + &empty_mapper + ), + plain_intersection + ); + + let plain_table_generic = + LuaType::TableGeneric(Arc::new(vec![LuaType::Number, LuaType::String])); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_table_generic, + &empty_mapper + ), + plain_table_generic + ); + + let plain_variadic = LuaType::Variadic(VariadicType::Base(LuaType::Number).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_variadic, + &empty_mapper + ), + plain_variadic + ); + + let plain_doc_function = LuaType::DocFunction( + LuaFunctionType::new( + AsyncState::None, + false, + false, + vec![("value".to_string(), Some(LuaType::Number))], + LuaType::String, + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_doc_function, + &empty_mapper + ), + plain_doc_function + ); + + let plain_type_guard = LuaType::TypeGuard(Arc::new(LuaType::Number)); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &plain_type_guard, + &empty_mapper + ), + plain_type_guard + ); + } + + #[test] + fn test_structural_instantiate_fast_path_instantiates_template_children() { + let db = DbIndex::new(); + let mapper = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let tpl = LuaType::TplRef(Arc::new(GenericTpl::new( + GenericTplId::Func(0), + SmolStr::new("T0").into(), + None, + None, + ))); + + let templated_array = LuaType::Array(LuaArrayType::from_base_type(tpl.clone()).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_array, + &mapper + ), + LuaType::Array(LuaArrayType::from_base_type(LuaType::String).into()) + ); + + let templated_tuple = LuaType::Tuple( + LuaTupleType::new( + vec![LuaType::Number, tpl.clone()], + LuaTupleStatus::DocResolve, + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_tuple, + &mapper + ), + LuaType::Tuple( + LuaTupleType::new( + vec![LuaType::Number, LuaType::String], + LuaTupleStatus::DocResolve, + ) + .into() + ) + ); + + let templated_object = LuaType::Object( + LuaObjectType::new_with_fields( + HashMap::from([ + (LuaMemberKey::Name(SmolStr::new("name")), tpl.clone()), + (LuaMemberKey::Name(SmolStr::new("count")), LuaType::Number), + ]), + Vec::new(), + ) + .into(), + ); + let expected_object = LuaType::Object( + LuaObjectType::new_with_fields( + HashMap::from([ + (LuaMemberKey::Name(SmolStr::new("name")), LuaType::String), + (LuaMemberKey::Name(SmolStr::new("count")), LuaType::Number), + ]), + Vec::new(), + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_object, + &mapper + ), + expected_object + ); + + let templated_union = + LuaType::Union(LuaUnionType::from_vec(vec![LuaType::Number, tpl.clone()]).into()); + let expected_union = + LuaType::Union(LuaUnionType::from_vec(vec![LuaType::Number, LuaType::String]).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_union, + &mapper + ), + expected_union + ); + + let templated_intersection = LuaType::Intersection( + LuaIntersectionType::new(vec![LuaType::Number, tpl.clone()]).into(), + ); + let expected_intersection = LuaType::Intersection( + LuaIntersectionType::new(vec![LuaType::Number, LuaType::String]).into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_intersection, + &mapper + ), + expected_intersection + ); + + let templated_table_generic = LuaType::TableGeneric(Arc::new(vec![tpl.clone()])); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_table_generic, + &mapper + ), + LuaType::TableGeneric(Arc::new(vec![LuaType::String])) + ); + + let templated_variadic = LuaType::Variadic(VariadicType::Base(tpl.clone()).into()); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_variadic, + &mapper + ), + LuaType::Variadic(VariadicType::Base(LuaType::String).into()) + ); + + let templated_doc_function = LuaType::DocFunction( + LuaFunctionType::new( + AsyncState::None, + false, + false, + vec![("value".to_string(), Some(tpl.clone()))], + tpl.clone(), + ) + .into(), + ); + let expected_doc_function = LuaType::DocFunction( + LuaFunctionType::new( + AsyncState::None, + false, + false, + vec![("value".to_string(), Some(LuaType::String))], + LuaType::String, + ) + .into(), + ); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_doc_function, + &mapper + ), + expected_doc_function + ); + + let templated_type_guard = LuaType::TypeGuard(Arc::new(tpl.clone())); + assert_eq!( + super::super::instantiate_type::instantiate_type_generic( + &db, + &templated_type_guard, + &mapper + ), + LuaType::TypeGuard(Arc::new(LuaType::String)) + ); + } + + #[test] + fn test_preserve_uninferred_keeps_unmapped_tpl_ref_and_const_tpl_ref() { + let db = DbIndex::new(); + let mapper = TypeMapper::from_uninferred(vec![GenericTplId::Func(0)]); + let tpl = func_tpl(0, "T0"); + + let tpl_ref = LuaType::TplRef(tpl.clone()); + assert_eq!( + instantiate_type_generic_full( + &db, + &tpl_ref, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ), + tpl_ref + ); + + let const_tpl_ref = LuaType::ConstTplRef(tpl.clone()); + assert_eq!( + instantiate_type_generic_full( + &db, + &const_tpl_ref, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ), + const_tpl_ref + ); + } + + #[test] + fn test_preserve_uninferred_applies_concrete_values_while_retaining_residual_templates() { + let db = DbIndex::new(); + let concrete = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let unresolved = + TypeMapper::from_uninferred(vec![GenericTplId::Func(1), GenericTplId::Func(2)]); + let mapper = TypeMapper::merge(Some(concrete), unresolved); + let t0 = func_tpl(0, "T0"); + let t1 = func_tpl(1, "T1"); + let t2 = func_tpl(2, "T2"); + + // 这段输入可以理解成 Lua 伪代码: + // ---@type [T0, T1, const T2] + // local value + // 套用 mapper 后,T0 会被具体化为 string, + // 但 T1 / T2 仍然要保留为残余模板,不能被过早折叠掉。 + let ty = LuaType::Tuple( + LuaTupleType::new( + vec![ + LuaType::TplRef(t0), + LuaType::TplRef(t1.clone()), + LuaType::ConstTplRef(t2.clone()), + ], + LuaTupleStatus::DocResolve, + ) + .into(), + ); + let preserved = instantiate_type_generic_full( + &db, + &ty, + &mapper, + None, + TplResolvePolicy::PreserveTplRef, + ); + let expected = LuaType::Tuple( + LuaTupleType::new( + vec![ + LuaType::String, + LuaType::TplRef(t1), + LuaType::ConstTplRef(t2), + ], + LuaTupleStatus::DocResolve, + ) + .into(), + ); + + assert_eq!(preserved, expected); + } + + #[test] + fn test_123() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param x T + ---@return T + function f(x) + return x + end + + A = f("hello") + B = f({value = "hello"}) + C = B.value + "#, + ); + + let a_ty = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(a_ty), "\"hello\""); + + let b_ty = ws.expr_ty("B"); + let b_desc = ws.humanize_type_detailed(b_ty); + assert!( + b_desc.contains("value: string"), + "unexpected type: {}", + b_desc + ); + + let c_ty = ws.expr_ty("C"); + assert_eq!(ws.humanize_type(c_ty), "string"); + } + + #[test] + fn test_regularize_tpl_candidate_type_preserves_root_primitive_and_widens_nested_literals() { + let mut ws = VirtualWorkspace::new(); + + let root_literal = ws.ty(r#""mode""#); + let regularized_root = { + let db = ws.analysis.compilation.get_db(); + regularize_tpl_candidate_type(db, root_literal.clone()) + }; + assert_eq!(regularized_root, root_literal); + + let table = ws.expr_ty(r#"{ kind = "mode", count = 1 }"#); + let regularized_table = { + let db = ws.analysis.compilation.get_db(); + regularize_tpl_candidate_type(db, table) + }; + assert_eq!(regularized_table, ws.ty("{ kind: string, count: integer }")); + } + + #[test] + fn test_widen_tpl_candidate_type_widens_root_primitive_and_structural_literals() { + let mut ws = VirtualWorkspace::new(); + + let root_literal = ws.ty(r#""mode""#); + let widened_root = { + let db = ws.analysis.compilation.get_db(); + widen_tpl_candidate_type(db, root_literal) + }; + assert_eq!(widened_root, LuaType::String); + + let root_union = ws.ty(r#""left" | "right""#); + let widened_root_union = { + let db = ws.analysis.compilation.get_db(); + widen_tpl_candidate_type(db, root_union.clone()) + }; + assert_eq!(widened_root_union, root_union); + + let tuple = ws.expr_ty(r#"{ "mode", 1 }"#); + let widened_tuple = { + let db = ws.analysis.compilation.get_db(); + widen_tpl_candidate_type(db, tuple) + }; + assert_eq!(ws.humanize_type(widened_tuple), "(string,integer)"); + + let table = ws.expr_ty(r#"{ kind = "mode", count = 1 }"#); + let widened_table = { + let db = ws.analysis.compilation.get_db(); + widen_tpl_candidate_type(db, table) + }; + assert_eq!(widened_table, ws.ty("{ kind: string, count: integer }")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs deleted file mode 100644 index 02200a4a6..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs +++ /dev/null @@ -1,11 +0,0 @@ -use emmylua_parser::LuaCallExpr; - -use crate::{DbIndex, LuaInferCache, TypeSubstitutor}; - -#[derive(Debug)] -pub struct TplContext<'a> { - pub db: &'a DbIndex, - pub cache: &'a mut LuaInferCache, - pub substitutor: &'a mut TypeSubstitutor, - pub call_expr: Option, -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs deleted file mode 100644 index aa556ca88..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ /dev/null @@ -1,136 +0,0 @@ -use crate::{ - InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaType, LuaTypeNode, TplContext, - TypeSubstitutor, instantiate_generic, instantiate_type_generic, - semantic::generic::tpl_pattern::{ - TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, - }, -}; - -pub fn generic_tpl_pattern_match( - context: &mut TplContext, - generic: &LuaGenericType, - target: &LuaType, -) -> TplPatternMatchResult { - generic_tpl_pattern_match_inner(context, generic, target, &InferGuard::new()) -} - -fn generic_tpl_pattern_match_inner( - context: &mut TplContext, - source_generic: &LuaGenericType, - target: &LuaType, - infer_guard: &InferGuardRef, -) -> TplPatternMatchResult { - match target { - LuaType::Generic(target_generic) => { - let base = source_generic.get_base_type_id_ref(); - let target_base = target_generic.get_base_type_id_ref(); - if base == target_base { - let params = source_generic.get_params(); - let target_params = target_generic.get_params(); - let min_len = params.len().min(target_params.len()); - for i in 0..min_len { - match (¶ms[i], &target_params[i]) { - (LuaType::Variadic(variadict), _) => { - variadic_tpl_pattern_match(context, variadict, &target_params[i..])?; - break; - } - _ => { - tpl_pattern_match(context, ¶ms[i], &target_params[i])?; - } - } - } - return Ok(()); - } - - let target_decl = context - .db - .get_type_index() - .get_type_decl(target_base) - .ok_or(InferFailReason::None)?; - if target_decl.is_alias() { - let substitutor = TypeSubstitutor::from_alias( - target_generic.get_params().clone(), - target_base.clone(), - ); - if let Some(origin_type) = - target_decl.get_alias_origin(context.db, Some(&substitutor)) - { - return generic_tpl_pattern_match_inner( - context, - source_generic, - &origin_type, - infer_guard, - ); - } - } else if let Some(super_types) = - context.db.get_type_index().get_super_types(target_base) - { - for mut super_type in super_types { - if super_type.contains_tpl_node() { - let substitutor = - TypeSubstitutor::from_type_array(target_generic.get_params().clone()); - super_type = - instantiate_type_generic(context.db, &super_type, &substitutor); - } - - generic_tpl_pattern_match_inner( - context, - source_generic, - &super_type, - &infer_guard.fork(), - )?; - } - } - } - LuaType::Ref(type_id) | LuaType::Def(type_id) => { - infer_guard.check(type_id)?; - let type_decl = context - .db - .get_type_index() - .get_type_decl(type_id) - .ok_or(InferFailReason::None)?; - if let Some(origin_type) = type_decl.get_alias_origin(context.db, None) { - return generic_tpl_pattern_match_inner( - context, - source_generic, - &origin_type, - infer_guard, - ); - } - - for super_type in context - .db - .get_type_index() - .get_super_types(type_id) - .unwrap_or_default() - { - generic_tpl_pattern_match_inner( - context, - source_generic, - &super_type, - &infer_guard.fork(), - )?; - } - } - LuaType::Union(union_type) => { - for union_sub_type in &union_type.into_vec() { - generic_tpl_pattern_match_inner( - context, - source_generic, - union_sub_type, - &infer_guard.fork(), - )?; - } - } - _ => { - // 对于 @alias 类型, 我们能拿到的 target 实际上很有可能是实例化后的类型, 因此我们需要实例化后再进行匹配 - let substitutor = TypeSubstitutor::new(); - let typ = instantiate_generic(context.db, source_generic, &substitutor); - if LuaType::from(source_generic.clone()) != typ { - tpl_pattern_match(context, &typ, target)?; - } - } - } - - Ok(()) -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs deleted file mode 100644 index 0f8751ea3..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::{ - InferFailReason, LuaSignatureId, LuaType, TplContext, infer_expr, - semantic::generic::tpl_pattern::TplPatternMatchResult, -}; - -pub fn check_lambda_tpl_pattern( - context: &mut TplContext, - signature_id: LuaSignatureId, -) -> TplPatternMatchResult { - let call_expr = context.call_expr.clone().ok_or(InferFailReason::None)?; - let call_arg_list = call_expr.get_args_list().ok_or(InferFailReason::None)?; - for arg in call_arg_list.get_args() { - if let Ok(LuaType::Signature(arg_signature_id)) = - infer_expr(context.db, context.cache, arg.clone()) - && arg_signature_id == signature_id - { - return Ok(()); - } - } - - Err(InferFailReason::UnResolveSignatureReturn(signature_id)) -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs deleted file mode 100644 index 77900f2d4..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ /dev/null @@ -1,1044 +0,0 @@ -mod generic_tpl_pattern; -mod lambda_tpl_pattern; - -use std::{ops::Deref, sync::Arc}; - -use emmylua_parser::LuaAstNode; -use itertools::Itertools; -use rowan::NodeOrToken; -use smol_str::SmolStr; - -use crate::{ - InferFailReason, LuaFunctionType, LuaMemberInfo, LuaMemberKey, LuaMemberOwner, LuaObjectType, - LuaSemanticDeclId, LuaTupleType, LuaTypeDeclId, LuaTypeNode, LuaUnionType, SemanticDeclLevel, - VariadicType, check_type_compact, - db_index::{DbIndex, LuaGenericType, LuaType}, - infer_node_semantic_decl, - semantic::{ - generic::{ - tpl_context::TplContext, tpl_pattern::generic_tpl_pattern::generic_tpl_pattern_match, - type_substitutor::SubstitutorValue, - }, - member::{find_index_operations, get_member_map}, - }, -}; - -use super::type_substitutor::TypeSubstitutor; -use std::collections::HashMap; - -type TplPatternMatchResult = Result<(), InferFailReason>; - -pub fn tpl_pattern_match_args( - context: &mut TplContext, - func_param_types: &[LuaType], - call_arg_types: &[LuaType], -) -> TplPatternMatchResult { - tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, false) -} - -pub fn tpl_pattern_match_args_skip_unknown( - context: &mut TplContext, - func_param_types: &[LuaType], - call_arg_types: &[LuaType], -) -> TplPatternMatchResult { - tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, true) -} - -fn tpl_pattern_match_args_inner( - context: &mut TplContext, - func_param_types: &[LuaType], - call_arg_types: &[LuaType], - skip_unknown_tpl: bool, -) -> TplPatternMatchResult { - for i in 0..func_param_types.len() { - if i >= call_arg_types.len() { - break; - } - - let func_param_type = &func_param_types[i]; - let call_arg_type = &call_arg_types[i]; - - match (func_param_type, call_arg_type) { - (LuaType::Variadic(variadic), _) => { - variadic_tpl_pattern_match(context, variadic, &call_arg_types[i..])?; - break; - } - (_, LuaType::Variadic(variadic)) => { - multi_param_tpl_pattern_match_multi_return( - context, - &func_param_types[i..], - variadic, - )?; - break; - } - _ if skip_unknown_tpl - && func_param_type.contain_tpl() - && (call_arg_type.is_any() || call_arg_type.is_unknown()) => {} - _ => { - tpl_pattern_match(context, func_param_type, call_arg_type)?; - } - } - } - - Ok(()) -} - -pub fn multi_param_tpl_pattern_match_multi_return( - context: &mut TplContext, - func_param_types: &[LuaType], - multi_return: &VariadicType, -) -> TplPatternMatchResult { - match &multi_return { - VariadicType::Base(base) => { - let mut call_arg_types = Vec::new(); - for param in func_param_types { - if param.is_variadic() { - call_arg_types.push(LuaType::Variadic(multi_return.clone().into())); - break; - } else { - call_arg_types.push(base.clone()); - } - } - - tpl_pattern_match_args(context, func_param_types, &call_arg_types)?; - } - VariadicType::Multi(_) => { - let mut call_arg_types = Vec::new(); - for (i, param) in func_param_types.iter().enumerate() { - let Some(return_type) = multi_return.get_type(i) else { - break; - }; - - if param.is_variadic() { - call_arg_types.push(LuaType::Variadic( - multi_return.get_new_variadic_from(i).into(), - )); - break; - } else { - call_arg_types.push(return_type.clone()); - } - } - - tpl_pattern_match_args(context, func_param_types, &call_arg_types)?; - } - } - - Ok(()) -} - -fn get_str_tpl_infer_type(name: &str) -> LuaType { - match name { - "unknown" => LuaType::Unknown, - "never" => LuaType::Never, - "nil" | "void" => LuaType::Nil, - "any" => LuaType::Any, - "userdata" => LuaType::Userdata, - "thread" => LuaType::Thread, - "boolean" | "bool" => LuaType::Boolean, - "string" => LuaType::String, - "integer" | "int" => LuaType::Integer, - "number" => LuaType::Number, - "io" => LuaType::Io, - "self" => LuaType::SelfInfer, - "global" => LuaType::Global, - "function" => LuaType::Function, - _ => LuaType::Ref(LuaTypeDeclId::global(&name)), - } -} - -pub fn tpl_pattern_match( - context: &mut TplContext, - pattern: &LuaType, - target: &LuaType, -) -> TplPatternMatchResult { - let target = escape_alias(context.db, target); - if !pattern.contains_tpl_node() { - return Ok(()); - } - - match pattern { - LuaType::TplRef(tpl) => { - if tpl.get_tpl_id().is_func() { - context - .substitutor - .insert_type(tpl.get_tpl_id(), target.clone(), true); - } - } - LuaType::ConstTplRef(tpl) => { - if tpl.get_tpl_id().is_func() { - context - .substitutor - .insert_type(tpl.get_tpl_id(), target, false); - } - } - LuaType::StrTplRef(str_tpl) => { - if let LuaType::StringConst(s) = target { - let prefix = str_tpl.get_prefix(); - let suffix = str_tpl.get_suffix(); - let type_name = SmolStr::new(format!("{}{}{}", prefix, s, suffix)); - context.substitutor.insert_type( - str_tpl.get_tpl_id(), - get_str_tpl_infer_type(&type_name), - true, - ); - } - } - LuaType::Array(array_type) => { - array_tpl_pattern_match(context, array_type.get_base(), &target)?; - } - LuaType::TableGeneric(table_generic_params) => { - table_generic_tpl_pattern_match(context, table_generic_params, &target)?; - } - LuaType::Generic(generic) => { - generic_tpl_pattern_match(context, generic, &target)?; - } - LuaType::Union(union) => { - union_tpl_pattern_match(context, union, &target)?; - } - LuaType::DocFunction(doc_func) => { - func_tpl_pattern_match(context, doc_func, &target)?; - } - LuaType::Tuple(tuple) => { - tuple_tpl_pattern_match(context, tuple, &target)?; - } - LuaType::Object(obj) => { - object_tpl_pattern_match(context, obj, &target)?; - } - _ => {} - } - - Ok(()) -} - -pub fn constant_decay(typ: LuaType) -> LuaType { - match &typ { - LuaType::FloatConst(_) => LuaType::Number, - LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, - LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, - LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, - _ => typ, - } -} - -fn object_tpl_pattern_match( - context: &mut TplContext, - origin_obj: &LuaObjectType, - target: &LuaType, -) -> TplPatternMatchResult { - match target { - LuaType::Object(target_object) => { - // 先匹配 fields - for (k, v) in origin_obj.get_fields().iter().sorted_by_key(|(k, _)| *k) { - let target_value = target_object.get_fields().get(k); - if let Some(target_value) = target_value { - tpl_pattern_match(context, v, target_value)?; - } - } - // 再匹配索引访问 - let target_index_access = target_object.get_index_access(); - for (origin_key, v) in origin_obj.get_index_access() { - // 先匹配 key 类型进行转换 - let target_access = target_index_access.iter().find(|(target_key, _)| { - check_type_compact(context.db, origin_key, target_key).is_ok() - }); - if let Some(target_access) = target_access { - tpl_pattern_match(context, origin_key, &target_access.0)?; - tpl_pattern_match(context, v, &target_access.1)?; - } - } - } - LuaType::TableConst(inst) => { - let owner = LuaMemberOwner::Element(inst.clone()); - object_tpl_pattern_match_member_owner_match(context, origin_obj, owner)?; - } - _ => {} - } - - Ok(()) -} - -fn object_tpl_pattern_match_member_owner_match( - context: &mut TplContext, - object: &LuaObjectType, - owner: LuaMemberOwner, -) -> TplPatternMatchResult { - let owner_type = match &owner { - LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()), - LuaMemberOwner::Type(type_id) => LuaType::Ref(type_id.clone()), - _ => { - return Err(InferFailReason::None); - } - }; - - let members = get_member_map(context.db, &owner_type).ok_or(InferFailReason::None)?; - for (k, v) in members { - let resolve_key = match &k { - LuaMemberKey::Integer(i) => Some(LuaType::IntegerConst(*i)), - LuaMemberKey::Name(s) => Some(LuaType::StringConst(s.clone().into())), - _ => None, - }; - let resolve_type = match v.len() { - 0 => LuaType::Any, - 1 => v[0].typ.clone(), - _ => { - let mut types = Vec::new(); - for m in &v { - types.push(m.typ.clone()); - } - LuaType::from_vec(types) - } - }; - - // this is a workaround, I need refactor infer member map - if resolve_type.is_unknown() - && !v.is_empty() - && let Some(LuaSemanticDeclId::Member(member_id)) = &v[0].property_owner_id - { - return Err(InferFailReason::UnResolveMemberType(*member_id)); - } - - if let Some(_) = resolve_key - && let Some(field_value) = object.get_field(&k) - { - tpl_pattern_match(context, field_value, &resolve_type)?; - } - } - - Ok(()) -} - -fn array_tpl_pattern_match( - context: &mut TplContext, - base: &LuaType, - target: &LuaType, -) -> TplPatternMatchResult { - match target { - LuaType::Array(target_array_type) => { - tpl_pattern_match(context, base, target_array_type.get_base())?; - } - LuaType::Tuple(target_tuple) => { - let target_base = target_tuple.cast_down_array_base(context.db); - tpl_pattern_match(context, base, &target_base)?; - } - LuaType::Object(target_object) => { - let target_base = target_object - .cast_down_array_base(context.db) - .ok_or(InferFailReason::None)?; - tpl_pattern_match(context, base, &target_base)?; - } - _ => {} - } - - Ok(()) -} - -fn table_generic_tpl_pattern_match( - context: &mut TplContext, - table_generic_params: &[LuaType], - target: &LuaType, -) -> TplPatternMatchResult { - if table_generic_params.len() != 2 { - return Err(InferFailReason::None); - } - - match target { - LuaType::TableGeneric(target_table_generic_params) => { - let min_len = table_generic_params - .len() - .min(target_table_generic_params.len()); - for i in 0..min_len { - tpl_pattern_match( - context, - &table_generic_params[i], - &target_table_generic_params[i], - )?; - } - } - LuaType::Array(target_array_base) => { - tpl_pattern_match(context, &table_generic_params[0], &LuaType::Integer)?; - tpl_pattern_match( - context, - &table_generic_params[1], - target_array_base.get_base(), - )?; - } - LuaType::Tuple(target_tuple) => { - let len = target_tuple.get_types().len(); - let mut keys = Vec::new(); - for i in 0..len { - keys.push(LuaType::IntegerConst((i as i64) + 1)); - } - - let key_type = LuaType::Union(LuaUnionType::from_vec(keys).into()); - let target_base = target_tuple.cast_down_array_base(context.db); - tpl_pattern_match(context, &table_generic_params[0], &key_type)?; - tpl_pattern_match(context, &table_generic_params[1], &target_base)?; - } - LuaType::TableConst(inst) => { - let owner = LuaMemberOwner::Element(inst.clone()); - table_generic_tpl_pattern_member_owner_match( - context, - table_generic_params, - owner, - &[], - )?; - } - LuaType::Ref(type_id) => { - let owner = LuaMemberOwner::Type(type_id.clone()); - table_generic_tpl_pattern_member_owner_match( - context, - table_generic_params, - owner, - &[], - )?; - } - LuaType::Def(type_id) => { - let owner = LuaMemberOwner::Type(type_id.clone()); - table_generic_tpl_pattern_member_owner_match( - context, - table_generic_params, - owner, - &[], - )?; - } - LuaType::Generic(generic) => { - let owner = LuaMemberOwner::Type(generic.get_base_type_id()); - let target_params = generic.get_params(); - table_generic_tpl_pattern_member_owner_match( - context, - table_generic_params, - owner, - target_params, - )?; - } - LuaType::Object(obj) => { - let mut keys = Vec::new(); - let mut values = Vec::new(); - for (k, v) in obj.get_fields() { - match k { - LuaMemberKey::Integer(i) => { - keys.push(LuaType::IntegerConst(*i)); - } - LuaMemberKey::Name(s) => { - keys.push(LuaType::StringConst(s.clone().into())); - } - _ => {} - }; - values.push(v.clone()); - } - for (k, v) in obj.get_index_access() { - keys.push(k.clone()); - values.push(v.clone()); - } - - let key_type = LuaType::Union(LuaUnionType::from_vec(keys).into()); - let value_type = LuaType::Union(LuaUnionType::from_vec(values).into()); - tpl_pattern_match(context, &table_generic_params[0], &key_type)?; - tpl_pattern_match(context, &table_generic_params[1], &value_type)?; - } - - LuaType::Global | LuaType::Any | LuaType::Table | LuaType::Userdata => { - // too many - tpl_pattern_match(context, &table_generic_params[0], &LuaType::Any)?; - tpl_pattern_match(context, &table_generic_params[1], &LuaType::Any)?; - } - _ => {} - } - - Ok(()) -} - -// KV 表匹配 ref/def/tableconst -fn table_generic_tpl_pattern_member_owner_match( - context: &mut TplContext, - table_generic_params: &[LuaType], - owner: LuaMemberOwner, - target_params: &[LuaType], -) -> TplPatternMatchResult { - if table_generic_params.len() != 2 { - return Err(InferFailReason::None); - } - - let owner_type = match &owner { - LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()), - LuaMemberOwner::Type(type_id) => match target_params.len() { - 0 => LuaType::Ref(type_id.clone()), - _ => LuaType::Generic(Arc::new(LuaGenericType::new( - type_id.clone(), - target_params.to_vec(), - ))), - }, - _ => { - return Err(InferFailReason::None); - } - }; - - let members = get_member_map(context.db, &owner_type).ok_or(InferFailReason::None)?; - // 如果是 pairs 调用, 我们需要尝试寻找元方法, 但目前`__pairs` 被放进成员表中 - if is_pairs_call(context).unwrap_or(false) - && try_handle_pairs_metamethod(context, table_generic_params, &members).is_ok() - { - return Ok(()); - } - - let target_key_type = table_generic_params[0].clone(); - let mut keys = Vec::new(); - let mut values = Vec::new(); - for (k, v) in members { - let key_type = match k { - LuaMemberKey::Integer(i) => LuaType::IntegerConst(i), - LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), - LuaMemberKey::ExprType(typ) => typ, - _ => continue, - }; - - if !target_key_type.is_generic() - && check_type_compact(context.db, &target_key_type, &key_type).is_err() - { - continue; - } - - keys.push(key_type); - - let resolve_type = match v.len() { - 0 => LuaType::Any, - 1 => v[0].typ.clone(), - _ => { - let mut types = Vec::new(); - for m in v { - types.push(m.typ.clone()); - } - LuaType::from_vec(types) - } - }; - - values.push(resolve_type); - } - - if keys.is_empty() { - find_index_operations(context.db, &owner_type) - .ok_or(InferFailReason::None)? - .iter() - .for_each(|m| { - if target_key_type.is_generic() { - return; - } - let key_type = match &m.key { - LuaMemberKey::ExprType(typ) => typ.clone(), - _ => return, - }; - if check_type_compact(context.db, &target_key_type, &key_type).is_ok() { - keys.push(key_type); - values.push(m.typ.clone()); - } - }); - } - - let key_type = match &keys[..] { - [] => return Err(InferFailReason::None), - [first] => first.clone(), - _ => LuaType::Union(LuaUnionType::from_vec(keys).into()), - }; - let value_type = match &values[..] { - [first] => first.clone(), - _ => LuaType::Union(LuaUnionType::from_vec(values).into()), - }; - - tpl_pattern_match(context, &table_generic_params[0], &key_type)?; - tpl_pattern_match(context, &table_generic_params[1], &value_type)?; - - Ok(()) -} - -fn union_tpl_pattern_match( - context: &mut TplContext, - union: &LuaUnionType, - target: &LuaType, -) -> TplPatternMatchResult { - let mut error_count = 0; - let mut last_error = InferFailReason::None; - for u in union.into_vec() { - match tpl_pattern_match(context, &u, target) { - // 返回 ok 时并不一定匹配成功, 仅表示没有发生错误 - Ok(_) => {} - Err(e) => { - error_count += 1; - last_error = e; - } - } - } - - if error_count == union.into_vec().len() { - Err(last_error) - } else { - Ok(()) - } -} - -fn func_tpl_pattern_match( - context: &mut TplContext, - tpl_func: &LuaFunctionType, - target: &LuaType, -) -> TplPatternMatchResult { - match target { - LuaType::DocFunction(target_doc_func) => { - func_tpl_pattern_match_doc_func(context, tpl_func, target_doc_func)?; - } - LuaType::Signature(signature_id) => { - let signature = context - .db - .get_signature_index() - .get(signature_id) - .ok_or(InferFailReason::None)?; - if !signature.is_resolve_return() { - return lambda_tpl_pattern::check_lambda_tpl_pattern(context, *signature_id); - } - let fake_doc_func = signature.to_doc_func_type(); - func_tpl_pattern_match_doc_func(context, tpl_func, &fake_doc_func)?; - } - _ => {} - } - - Ok(()) -} - -fn func_tpl_pattern_match_doc_func( - context: &mut TplContext, - tpl_func: &LuaFunctionType, - target_func: &LuaFunctionType, -) -> TplPatternMatchResult { - let mut tpl_func_params = tpl_func.get_params().to_vec(); - if tpl_func.is_colon_define() { - tpl_func_params.insert(0, ("self".to_string(), Some(LuaType::Any))); - } - - let mut target_func_params = target_func.get_params().to_vec(); - - if target_func.is_colon_define() { - target_func_params.insert(0, ("self".to_string(), Some(LuaType::Any))); - } - - param_type_list_pattern_match_type_list(context, &tpl_func_params, &target_func_params)?; - - let tpl_return = tpl_func.get_ret(); - let target_return = target_func.get_ret(); - return_type_pattern_match_target_type(context, tpl_return, target_return)?; - - Ok(()) -} - -fn param_type_list_pattern_match_type_list( - context: &mut TplContext, - sources: &[(String, Option)], - targets: &[(String, Option)], -) -> TplPatternMatchResult { - let type_len = sources.len(); - let mut target_offset = 0; - for i in 0..type_len { - let source = match sources.get(i) { - Some(t) => t.1.clone().unwrap_or(LuaType::Any), - None => break, - }; - - match &source { - LuaType::Variadic(inner) => { - let i = i + target_offset; - if i >= targets.len() { - if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() { - let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); - } - break; - } - - if let VariadicType::Base(LuaType::TplRef(generic_tpl)) = inner.deref() { - let tpl_id = generic_tpl.get_tpl_id(); - if let Some(inferred_type_value) = context.substitutor.get(tpl_id) { - match inferred_type_value { - SubstitutorValue::Type(_) => { - continue; - } - SubstitutorValue::MultiTypes(types) => { - if types.len() > 1 { - target_offset += types.len() - 1; - } - continue; - } - SubstitutorValue::Params(params) => { - if params.len() > 1 { - target_offset += params.len() - 1; - } - continue; - } - _ => {} - } - } - } - - let mut target_rest_params = &targets[i..]; - // If the variadic parameter is not the last one, then target_rest_params should exclude the parameters that come after it. - if i + 1 < type_len { - let source_rest_len = type_len - i - 1; - if source_rest_len >= target_rest_params.len() { - continue; - } - let target_rest_len = target_rest_params.len() - source_rest_len; - target_rest_params = &target_rest_params[..target_rest_len]; - if target_rest_len > 1 { - target_offset += target_rest_len - 1; - } - } - - func_varargs_tpl_pattern_match(inner, target_rest_params, context.substitutor)?; - } - _ => { - let target = match targets.get(i + target_offset) { - Some(t) => t.1.clone().unwrap_or(LuaType::Any), - None => break, - }; - tpl_pattern_match(context, &source, &target)?; - } - } - } - - Ok(()) -} - -pub(crate) fn return_type_pattern_match_target_type( - context: &mut TplContext, - source: &LuaType, - target: &LuaType, -) -> TplPatternMatchResult { - match (source, target) { - // toooooo complex - (LuaType::Variadic(variadic_source), LuaType::Variadic(variadic_target)) => { - match variadic_target.deref() { - VariadicType::Base(target_base) => match variadic_source.deref() { - VariadicType::Base(source_base) => { - if let LuaType::TplRef(type_ref) = source_base { - let tpl_id = type_ref.get_tpl_id(); - context - .substitutor - .insert_type(tpl_id, target_base.clone(), true); - } - } - VariadicType::Multi(source_multi) => { - for ret_type in source_multi { - match ret_type { - LuaType::Variadic(inner) => { - if let VariadicType::Base(base) = inner.deref() - && let LuaType::TplRef(type_ref) = base - { - let tpl_id = type_ref.get_tpl_id(); - context.substitutor.insert_type( - tpl_id, - target_base.clone(), - true, - ); - } - - break; - } - LuaType::TplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.insert_type( - tpl_id, - target_base.clone(), - true, - ); - } - _ => {} - } - } - } - }, - VariadicType::Multi(target_types) => { - variadic_tpl_pattern_match(context, variadic_source, target_types)?; - } - } - } - (LuaType::Variadic(variadic), _) => { - variadic_tpl_pattern_match(context, variadic, std::slice::from_ref(target))?; - } - (_, LuaType::Variadic(variadic)) => { - multi_param_tpl_pattern_match_multi_return( - context, - std::slice::from_ref(source), - variadic, - )?; - } - _ => { - tpl_pattern_match(context, source, target)?; - } - } - - Ok(()) -} - -fn func_varargs_tpl_pattern_match( - variadic: &VariadicType, - target_rest_params: &[(String, Option)], - substitutor: &mut TypeSubstitutor, -) -> TplPatternMatchResult { - match variadic { - VariadicType::Base(base) => { - if let LuaType::TplRef(tpl_ref) = base { - let tpl_id = tpl_ref.get_tpl_id(); - substitutor.insert_params( - tpl_id, - target_rest_params - .iter() - .map(|(n, t)| (n.clone(), t.clone())) - .collect(), - ); - } - } - VariadicType::Multi(_) => {} - } - - Ok(()) -} - -pub fn variadic_tpl_pattern_match( - context: &mut TplContext, - tpl: &VariadicType, - target_rest_types: &[LuaType], -) -> TplPatternMatchResult { - match tpl { - VariadicType::Base(base) => match base { - LuaType::TplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.len() { - 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); - } - 1 => { - // If the single argument is itself a multi-return (e.g. a function call - // returning multiple values), expand it so that `T...` receives all the - // return values rather than a single Variadic wrapper. - match &target_rest_types[0] { - LuaType::Variadic(variadic) => match variadic.deref() { - VariadicType::Multi(types) => match types.len() { - 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); - } - 1 => { - context.substitutor.insert_type( - tpl_id, - types[0].clone(), - true, - ); - } - _ => { - context.substitutor.insert_multi_types( - tpl_id, - types - .iter() - .map(|t| constant_decay(t.clone())) - .collect(), - ); - } - }, - VariadicType::Base(base) => { - context.substitutor.insert_multi_base(tpl_id, base.clone()); - } - }, - arg => { - context.substitutor.insert_type(tpl_id, arg.clone(), true); - } - } - } - _ => { - context.substitutor.insert_multi_types( - tpl_id, - target_rest_types - .iter() - .map(|t| constant_decay(t.clone())) - .collect(), - ); - } - } - } - LuaType::ConstTplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.len() { - 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, false); - } - 1 => { - context.substitutor.insert_type( - tpl_id, - target_rest_types[0].clone(), - false, - ); - } - _ => { - context - .substitutor - .insert_multi_types(tpl_id, target_rest_types.to_vec()); - } - } - } - _ => {} - }, - VariadicType::Multi(multi) => { - for (i, ret_type) in multi.iter().enumerate() { - match ret_type { - LuaType::Variadic(inner) => { - if i < target_rest_types.len() { - variadic_tpl_pattern_match(context, inner, &target_rest_types[i..])?; - } - - break; - } - LuaType::TplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.get(i) { - Some(t) => { - context.substitutor.insert_type(tpl_id, t.clone(), true); - } - None => { - break; - } - }; - } - _ => {} - } - } - } - } - - Ok(()) -} - -fn tuple_tpl_pattern_match( - context: &mut TplContext, - tpl_tuple: &LuaTupleType, - target: &LuaType, -) -> TplPatternMatchResult { - match target { - LuaType::Tuple(target_tuple) => { - let tpl_tuple_types = tpl_tuple.get_types(); - let target_tuple_types = target_tuple.get_types(); - let tpl_tuple_len = tpl_tuple_types.len(); - for i in 0..tpl_tuple_len { - let tpl_type = &tpl_tuple_types[i]; - - if let LuaType::Variadic(inner) = tpl_type { - let target_rest_types = &target_tuple_types[i..]; - variadic_tpl_pattern_match(context, inner, target_rest_types)?; - break; - } - - let target_type = match target_tuple_types.get(i) { - Some(t) => t, - None => break, - }; - - tpl_pattern_match(context, tpl_type, target_type)?; - } - } - LuaType::Array(target_array_base) => { - let tupl_tuple_types = tpl_tuple.get_types(); - let last_type = tupl_tuple_types.last().ok_or(InferFailReason::None)?; - if let LuaType::Variadic(inner) = last_type { - match inner.deref() { - VariadicType::Base(base) => { - if let LuaType::TplRef(tpl_ref) = base { - let tpl_id = tpl_ref.get_tpl_id(); - context - .substitutor - .insert_multi_base(tpl_id, target_array_base.get_base().clone()); - } - } - VariadicType::Multi(_) => {} - } - } - } - _ => {} - } - - Ok(()) -} - -fn escape_alias(db: &DbIndex, may_alias: &LuaType) -> LuaType { - if let LuaType::Ref(type_id) = may_alias - && let Some(type_decl) = db.get_type_index().get_type_decl(type_id) - && type_decl.is_alias() - && let Some(origin_type) = type_decl.get_alias_origin(db, None) - { - return origin_type.clone(); - } - - may_alias.clone() -} - -fn is_pairs_call(context: &mut TplContext) -> Option { - let call_expr = context.call_expr.as_ref()?; - let prefix_expr = call_expr.get_prefix_expr()?; - let semantic_decl = match prefix_expr.syntax().clone().into() { - NodeOrToken::Node(node) => infer_node_semantic_decl( - context.db, - context.cache, - node, - SemanticDeclLevel::default(), - ), - _ => None, - }?; - - let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl else { - return None; - }; - let decl = context.db.get_decl_index().get_decl(&decl_id)?; - if !context.db.get_module_index().is_std(&decl.get_file_id()) { - return None; - } - let name = decl.get_name(); - if name != "pairs" { - return None; - } - Some(true) -} - -fn try_handle_pairs_metamethod( - context: &mut TplContext, - table_generic_params: &[LuaType], - members: &HashMap>, -) -> TplPatternMatchResult { - let pairs_member = members - .get(&LuaMemberKey::Name("__pairs".into())) - .ok_or(InferFailReason::None)? - .first() - .ok_or(InferFailReason::None)?; - // 获取迭代函数返回类型 - let meta_return = match &pairs_member.typ { - LuaType::Signature(signature_id) => context - .db - .get_signature_index() - .get(signature_id) - .map(|s| s.get_return_type()), - LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), - _ => None, - } - .ok_or(InferFailReason::None)?; - - // 解析出迭代函数返回类型 - let final_return_type = match meta_return { - LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), - LuaType::Signature(signature_id) => context - .db - .get_signature_index() - .get(&signature_id) - .map(|s| s.get_return_type()), - _ => None, - }; - - if let Some(LuaType::Variadic(variadic)) = &final_return_type { - let key_type = variadic.get_type(0).ok_or(InferFailReason::None)?; - let value_type = variadic.get_type(1).ok_or(InferFailReason::None)?; - tpl_pattern_match(context, &table_generic_params[0], key_type)?; - tpl_pattern_match(context, &table_generic_params[1], value_type)?; - return Ok(()); - } - Err(InferFailReason::None) -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_mapper.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_mapper.rs new file mode 100644 index 000000000..cab2afd50 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_mapper.rs @@ -0,0 +1,364 @@ +use std::{cell::RefCell, rc::Rc}; + +use hashbrown::HashMap; + +use crate::{DbIndex, GenericTplId, LuaType, LuaTypeDeclId, VariadicType}; + +#[derive(Debug, Clone)] +pub enum TypeMapper { + Simple { + source: GenericTplId, + target: TypeMapperValue, + }, + Array { + mappings: Rc>, + }, + InferenceFallback { + indices: Rc>, + targets: Rc>>, + }, + Merged { + layers: Rc<[TypeMapper]>, + }, +} + +impl TypeMapper { + pub fn empty() -> Self { + Self::from_values(Vec::new(), Vec::new()) + } + + pub fn from_values(sources: Vec, targets: Vec) -> Self { + if sources.len() == 1 { + return TypeMapper::Simple { + source: sources[0], + target: targets + .into_iter() + .next() + .unwrap_or(TypeMapperValue::Type(LuaType::Any)), + }; + } + + let mut mappings = HashMap::with_capacity(sources.len().min(targets.len())); + for (source, target) in sources.into_iter().zip(targets) { + mappings.entry(source).or_insert(target); + } + TypeMapper::Array { + mappings: Rc::new(mappings), + } + } + + pub fn from_type_array(type_array: Vec) -> Self { + let sources = (0..type_array.len()) + .map(|idx| GenericTplId::Type(idx as u32)) + .collect::>(); + let targets = type_array + .into_iter() + .map(TypeMapperValue::type_value) + .collect(); + Self::from_values(sources, targets) + } + + pub fn from_uninferred(sources: Vec) -> Self { + let targets = sources + .iter() + .map(|_| TypeMapperValue::None) + .collect::>(); + Self::from_values(sources, targets) + } + + pub fn from_alias( + db: &DbIndex, + type_array: Vec, + alias_type_id: &LuaTypeDeclId, + ) -> Self { + let params = db.get_type_index().get_generic_params(alias_type_id); + let sources = type_array + .iter() + .enumerate() + .map(|(i, _)| { + params + .and_then(|params| params.get(i)) + .and_then(|param| param.tpl_id) + .unwrap_or(GenericTplId::Type(i as u32)) + }) + .collect::>(); + let targets = type_array + .into_iter() + .map(TypeMapperValue::type_value) + .collect(); + Self::from_values(sources, targets) + } + + fn collect_layers(mapper: TypeMapper, layers: &mut Vec) { + match mapper { + TypeMapper::Merged { + layers: nested_layers, + } => { + for layer in nested_layers.iter().cloned() { + Self::collect_layers(layer, layers); + } + } + other => layers.push(other), + } + } + + pub fn from_inference_fallback( + indices: Rc>, + targets: Rc>>, + ) -> Self { + TypeMapper::InferenceFallback { indices, targets } + } + + pub fn merge(mapper1: Option, mapper2: TypeMapper) -> Self { + let mut layers = Vec::new(); + if let Some(mapper1) = mapper1 { + Self::collect_layers(mapper1, &mut layers); + } + Self::collect_layers(mapper2, &mut layers); + match layers.len() { + 0 => Self::empty(), + 1 => layers.remove(0), + _ => TypeMapper::Merged { + layers: Rc::from(layers.into_boxed_slice()), + }, + } + } + + pub fn prepend(source: GenericTplId, target: LuaType, mapper: Option) -> Self { + let unary = TypeMapper::Simple { + source, + target: TypeMapperValue::type_value(target), + }; + match mapper { + Some(mapper) => Self::merge(Some(unary), mapper), + None => unary, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TypeMapperValue { + None, + Type(LuaType), + Params(Vec<(String, Option)>), + MultiTypes(Vec), + MultiBase(LuaType), +} + +impl TypeMapperValue { + pub fn type_value(ty: LuaType) -> Self { + TypeMapperValue::Type(into_ref_type(ty)) + } + + pub fn params_value(params: Vec<(String, Option)>) -> Self { + TypeMapperValue::Params( + params + .into_iter() + .map(|(name, ty)| (name, ty.map(into_ref_type))) + .collect(), + ) + } + + pub fn raw_type(&self) -> Option { + match self { + TypeMapperValue::Type(ty) => Some(ty.clone()), + TypeMapperValue::Params(params) => params + .first() + .and_then(|(_, ty)| ty.clone()) + .or(Some(LuaType::Unknown)), + TypeMapperValue::MultiTypes(types) => { + Some(LuaType::Variadic(VariadicType::Multi(types.clone()).into())) + } + TypeMapperValue::MultiBase(base) => Some(base.clone()), + TypeMapperValue::None => None, + } + } + + fn direct_tpl_id(&self) -> Option { + match self { + TypeMapperValue::Type(LuaType::TplRef(tpl)) + | TypeMapperValue::Type(LuaType::ConstTplRef(tpl)) => Some(tpl.get_tpl_id()), + _ => None, + } + } +} + +pub(in crate::semantic::generic) fn get_mapped_value( + tpl_id: GenericTplId, + mapper: &TypeMapper, +) -> Option { + match mapper { + TypeMapper::Simple { source, target } => { + if *source == tpl_id { + Some(target.clone()) + } else { + None + } + } + TypeMapper::Array { mappings } => mappings.get(&tpl_id).cloned(), + TypeMapper::InferenceFallback { indices, targets } => { + let index = *indices.get(&tpl_id)?; + targets.borrow().get(index).cloned() + } + TypeMapper::Merged { layers } => { + let mut current_tpl_id = tpl_id; + let mut current_value: Option = None; + let mut has_direct_value = false; + + for layer in layers.iter() { + if let Some(value) = get_mapped_value(current_tpl_id, layer) { + if let Some(mapped_tpl_id) = value.direct_tpl_id() { + current_tpl_id = mapped_tpl_id; + current_value = Some(value); + has_direct_value = true; + continue; + } + + if matches!(value, TypeMapperValue::None) { + if has_direct_value { + continue; + } + return Some(TypeMapperValue::None); + } + + return Some(value); + } + } + + current_value + } + } +} + +fn into_ref_type(ty: LuaType) -> LuaType { + match ty { + LuaType::Def(type_decl_id) => LuaType::Ref(type_decl_id), + _ => ty, + } +} + +#[cfg(test)] +mod tests { + use std::{cell::RefCell, rc::Rc, sync::Arc}; + + use smol_str::SmolStr; + + use super::*; + use crate::{GenericTpl, GenericTplId, LuaArrayType}; + + fn tpl_ref(idx: u32) -> LuaType { + LuaType::TplRef(Arc::new(GenericTpl::new( + GenericTplId::Func(idx), + SmolStr::new(format!("T{}", idx)).into(), + None, + None, + ))) + } + + // 合并后的 mapper 只读直接结果, 不在查询阶段深度展开结构里的模板. + #[test] + fn merged_does_not_deep_instantiate_mapped_result() { + let first = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(LuaType::Array( + LuaArrayType::from_base_type(tpl_ref(1)).into(), + ))], + ); + let second = TypeMapper::from_values( + vec![GenericTplId::Func(1)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let mapper = TypeMapper::merge(Some(first), second); + + let mapped = + get_mapped_value(GenericTplId::Func(0), &mapper).and_then(|value| value.raw_type()); + + assert_eq!( + mapped, + Some(LuaType::Array( + LuaArrayType::from_base_type(tpl_ref(1)).into() + )) + ); + } + + // 直接的 TplRef 链要继续追到后层 mapper. + #[test] + fn merged_resolves_direct_tpl_ref_through_later_mapper() { + let first = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(tpl_ref(1))], + ); + let second = TypeMapper::from_values( + vec![GenericTplId::Func(1)], + vec![TypeMapperValue::type_value(LuaType::String)], + ); + let mapper = TypeMapper::merge(Some(first), second); + + let mapped = + get_mapped_value(GenericTplId::Func(0), &mapper).and_then(|value| value.raw_type()); + + assert_eq!(mapped, Some(LuaType::String)); + } + + // 未推断的槽位要保留为 None, 不能和“没有映射”混掉. + #[test] + fn inference_fallback_keeps_unresolved_slots_as_none() { + let mut indices = HashMap::new(); + indices.insert(GenericTplId::Func(0), 0); + indices.insert(GenericTplId::Func(1), 1); + let targets = Rc::new(RefCell::new(vec![ + TypeMapperValue::None, + TypeMapperValue::None, + ])); + let mapper = TypeMapper::from_inference_fallback(Rc::new(indices), targets); + + assert_eq!( + get_mapped_value(GenericTplId::Func(1), &mapper), + Some(TypeMapperValue::None) + ); + } + + // 后层显式 None 不能抹掉前层已经建立的 TplRef 链. + #[test] + fn merged_keeps_direct_tpl_ref_when_later_mapper_is_explicit_none() { + let first = TypeMapper::from_values( + vec![GenericTplId::Func(0)], + vec![TypeMapperValue::type_value(tpl_ref(1))], + ); + let second = + TypeMapper::from_values(vec![GenericTplId::Func(1)], vec![TypeMapperValue::None]); + let mapper = TypeMapper::merge(Some(first), second); + + let mapped = + get_mapped_value(GenericTplId::Func(0), &mapper).and_then(|value| value.raw_type()); + + assert_eq!(mapped, Some(tpl_ref(1))); + } + + // 长链别名要能一路追到最终具体类型. + #[test] + fn long_mapper_chain_resolves_transitively() { + let mut mapper = None; + for idx in 0..64 { + let source = GenericTplId::Func(idx); + let target = if idx == 63 { + LuaType::String + } else { + tpl_ref(idx + 1) + }; + let layer = + TypeMapper::from_values(vec![source], vec![TypeMapperValue::type_value(target)]); + mapper = Some(TypeMapper::merge(mapper, layer)); + } + + let mapper = mapper.expect("mapper"); + assert_eq!( + get_mapped_value(GenericTplId::Func(0), &mapper).and_then(|value| value.raw_type()), + Some(LuaType::String) + ); + assert_eq!( + get_mapped_value(GenericTplId::Func(31), &mapper).and_then(|value| value.raw_type()), + Some(LuaType::String) + ); + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs deleted file mode 100644 index b045bda1d..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ /dev/null @@ -1,316 +0,0 @@ -use hashbrown::{HashMap, HashSet}; - -use super::tpl_pattern::constant_decay; -use crate::{DbIndex, GenericTplId, LuaType, LuaTypeDeclId}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(super) enum UninferredTplPolicy { - /// 未推断模板按 `default -> constraint -> unknown` 推断成实际类型. - Fallback, - /// 没有默认值的未推断模板仍保留为 `TplRef`, 让后续调用点继续参与参数推导. - PreserveTplRef, -} - -#[derive(Debug)] -pub struct GenericInstantiateContext<'a> { - pub db: &'a DbIndex, - pub substitutor: &'a TypeSubstitutor, - policy: UninferredTplPolicy, -} - -impl<'a> GenericInstantiateContext<'a> { - pub fn new(db: &'a DbIndex, substitutor: &'a TypeSubstitutor) -> Self { - Self { - db, - substitutor, - policy: UninferredTplPolicy::Fallback, - } - } - - pub(super) fn with_policy(&self, policy: UninferredTplPolicy) -> GenericInstantiateContext<'a> { - GenericInstantiateContext { - db: self.db, - substitutor: self.substitutor, - policy, - } - } - - pub fn with_substitutor<'b>( - &'b self, - substitutor: &'b TypeSubstitutor, - ) -> GenericInstantiateContext<'b> { - GenericInstantiateContext { - db: self.db, - substitutor, - policy: self.policy, - } - } - - pub fn should_preserve_tpl_ref(&self) -> bool { - self.policy == UninferredTplPolicy::PreserveTplRef - } -} - -#[derive(Debug, Clone)] -pub struct TypeSubstitutor { - tpl_replace_map: HashMap, - alias_type_id: Option, - self_type: Option, -} - -impl Default for TypeSubstitutor { - fn default() -> Self { - Self::new() - } -} - -impl TypeSubstitutor { - pub fn new() -> Self { - Self { - tpl_replace_map: HashMap::new(), - alias_type_id: None, - self_type: None, - } - } - - pub fn from_type_array(type_array: Vec) -> Self { - let mut tpl_replace_map = HashMap::new(); - for (i, ty) in type_array.into_iter().enumerate() { - tpl_replace_map.insert( - GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new(ty, true)), - ); - } - Self { - tpl_replace_map, - alias_type_id: None, - self_type: None, - } - } - - pub fn from_alias(type_array: Vec, alias_type_id: LuaTypeDeclId) -> Self { - let mut tpl_replace_map = HashMap::new(); - for (i, ty) in type_array.into_iter().enumerate() { - tpl_replace_map.insert( - GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new(ty, true)), - ); - } - Self { - tpl_replace_map, - alias_type_id: Some(alias_type_id), - self_type: None, - } - } - - pub fn add_need_infer_tpls(&mut self, tpl_ids: HashSet) { - for tpl_id in tpl_ids { - // conditional infer id 只属于条件类型内部匹配, 不参与普通调用/类型泛型推导. - if tpl_id.is_conditional_infer() { - continue; - } - - self.tpl_replace_map - .entry(tpl_id) - .or_insert(SubstitutorValue::None); - } - } - - pub fn is_infer_all_tpl(&self) -> bool { - for value in self.tpl_replace_map.values() { - if let SubstitutorValue::None = value { - return false; - } - } - true - } - - pub fn insert_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType, decay: bool) { - // 普通替换入口不能写入 conditional infer, 避免条件类型局部绑定泄露到外层. - if tpl_id.is_conditional_infer() { - return; - } - - self.insert_type_value(tpl_id, SubstitutorTypeValue::new(replace_type, decay)); - } - - pub(super) fn replace_type( - &mut self, - tpl_id: GenericTplId, - replace_type: LuaType, - decay: bool, - ) { - if tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, decay)), - ); - } - - pub fn insert_conditional_infer_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType) { - // 只有 conditional true 分支提交 infer 结果时允许写入 scoped conditional infer id. - if !tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, false)), - ); - } - - fn insert_type_value(&mut self, tpl_id: GenericTplId, value: SubstitutorTypeValue) { - if !self.can_insert_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::Type(value)); - } - - fn can_insert_type(&self, tpl_id: GenericTplId) -> bool { - if let Some(value) = self.tpl_replace_map.get(&tpl_id) { - return value.is_none(); - } - - true - } - - pub fn insert_params(&mut self, tpl_id: GenericTplId, params: Vec<(String, Option)>) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - let params = params - .into_iter() - .map(|(name, ty)| (name, ty.map(into_ref_type))) - .collect(); - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::Params(params)); - } - - pub fn insert_multi_types(&mut self, tpl_id: GenericTplId, types: Vec) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiTypes(types)); - } - - pub fn insert_multi_base(&mut self, tpl_id: GenericTplId, type_base: LuaType) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiBase(type_base)); - } - - pub fn get(&self, tpl_id: GenericTplId) -> Option<&SubstitutorValue> { - self.tpl_replace_map.get(&tpl_id) - } - - pub fn get_raw_type(&self, tpl_id: GenericTplId) -> Option<&LuaType> { - match self.tpl_replace_map.get(&tpl_id) { - Some(SubstitutorValue::Type(ty)) => Some(ty.raw()), - _ => None, - } - } - - pub fn check_recursion(&self, type_id: &LuaTypeDeclId) -> bool { - if let Some(alias_type_id) = &self.alias_type_id - && alias_type_id == type_id - { - return true; - } - - false - } - - pub fn add_self_type(&mut self, self_type: LuaType) { - self.self_type = Some(self_type); - } - - pub fn get_self_type(&self) -> Option<&LuaType> { - self.self_type.as_ref() - } -} - -#[derive(Debug, Clone)] -pub struct SubstitutorTypeValue { - raw: LuaType, - decayed: DecayedType, -} - -#[derive(Debug, Clone)] -enum DecayedType { - Same, - Cached(LuaType), -} - -impl SubstitutorTypeValue { - pub fn new(raw: LuaType, decay: bool) -> Self { - let raw = into_ref_type(raw); - let decayed = if decay { - let decayed = into_ref_type(constant_decay(raw.clone())); - if decayed == raw { - DecayedType::Same - } else { - DecayedType::Cached(decayed) - } - } else { - DecayedType::Same - }; - Self { raw, decayed } - } - - pub fn raw(&self) -> &LuaType { - &self.raw - } - - pub fn default(&self) -> &LuaType { - match &self.decayed { - DecayedType::Same => &self.raw, - DecayedType::Cached(decayed) => decayed, - } - } -} - -#[derive(Debug, Clone)] -pub enum SubstitutorValue { - None, - Type(SubstitutorTypeValue), - Params(Vec<(String, Option)>), - MultiTypes(Vec), - MultiBase(LuaType), -} - -impl SubstitutorValue { - pub fn is_none(&self) -> bool { - matches!(self, SubstitutorValue::None) - } -} - -fn into_ref_type(ty: LuaType) -> LuaType { - match ty { - LuaType::Def(type_decl_id) => LuaType::Ref(type_decl_id), - _ => ty, - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index cd1360cf2..4d000ba9c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -8,7 +8,7 @@ use super::{ super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature}, InferFailReason, InferResult, }; -use crate::semantic::overload_resolve::callable_accepts_args; +use crate::semantic::overload_resolve::{callable_accepts_args, collect_callable_overload_groups}; use crate::{ AsyncState, CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId, @@ -17,14 +17,14 @@ use crate::{ use crate::{ InferGuardRef, semantic::{ - generic::{ - TypeSubstitutor, collect_callable_overload_groups, get_tpl_ref_extend_type, - instantiate_doc_function, - }, + generic::{TypeMapper, get_tpl_ref_extend_type}, infer::narrow::get_type_at_call_expr_inline_cast, }, }; -use crate::{build_self_type, infer_self_type, instantiate_func_generic, semantic::infer_expr}; +use crate::{ + TplResolvePolicy, build_self_type, infer_call_func_generic, infer_self_type, + instantiate_type_generic_full, semantic::infer_expr, +}; use infer_require::infer_require_call; use infer_setmetatable::infer_setmetatable_call; @@ -134,14 +134,13 @@ pub fn infer_call_expr_func( }; let result = if let Ok(func_ty) = result { - let func_ty = match func_ty.get_ret() { - LuaType::Call(_) => { - match instantiate_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { - Ok(func_ty) => Arc::new(func_ty), - Err(_) => func_ty, - } + let func_ty = if func_ty.get_ret().contain_tpl() || func_ty.get_ret().is_call() { + match infer_call_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { + Ok(func_ty) => Arc::new(func_ty), + Err(_) => func_ty, } - _ => func_ty, + } else { + func_ty }; let func_ret = func_ty.get_ret(); @@ -223,7 +222,7 @@ fn infer_doc_function( call_expr: LuaCallExpr, ) -> InferCallFuncResult { if func.contain_tpl() { - let result = instantiate_func_generic(db, cache, func, call_expr)?; + let result = infer_call_func_generic(db, cache, func, call_expr)?; return Ok(Arc::new(result)); } @@ -270,10 +269,10 @@ fn filter_callable_overloads_by_call_args( } let has_tpls = !callable_tpls.is_empty(); - let mut substitutor = TypeSubstitutor::new(); - substitutor.add_need_infer_tpls(callable_tpls); + let mapper = TypeMapper::from_uninferred(callable_tpls.into_iter().collect()); let match_func = if has_tpls { - match instantiate_doc_function(db, func, &substitutor) { + let func_ty = LuaType::DocFunction(func.clone()); + match instantiate_type_generic(db, &func_ty, &mapper) { LuaType::DocFunction(doc_func) => doc_func, _ => func.clone(), } @@ -362,14 +361,19 @@ fn infer_type_doc_function( }; if has_generic_tpl { - let result = instantiate_func_generic(db, cache, &f, call_expr.clone())?; + let result = infer_call_func_generic(db, cache, &f, call_expr.clone())?; overloads.push(Arc::new(result)); } else if f.contain_self() { - let mut substitutor = TypeSubstitutor::new(); + let mapper = TypeMapper::empty(); let self_type = build_self_type(db, call_expr_type); - substitutor.add_self_type(self_type); - if let LuaType::DocFunction(f) = instantiate_doc_function(db, &f, &substitutor) - { + let func_ty = LuaType::DocFunction(f.clone()); + if let LuaType::DocFunction(f) = instantiate_type_generic_full( + db, + &func_ty, + &mapper, + Some(&self_type), + TplResolvePolicy::Fallback, + ) { overloads.push(f); } } else { @@ -405,7 +409,7 @@ fn infer_generic_type_doc_function( let type_id = generic.get_base_type_id(); infer_guard.check(&type_id)?; let generic_params = generic.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); + let mapper = TypeMapper::from_type_array(generic_params.clone()); let type_decl = db .get_type_index() @@ -413,7 +417,7 @@ fn infer_generic_type_doc_function( .ok_or_else(|| InferFailReason::UnResolveTypeDecl(type_id.clone()))?; if type_decl.is_alias() { let origin_type = type_decl - .get_alias_origin(db, Some(&substitutor)) + .get_alias_origin(db, Some(&mapper)) .ok_or(InferFailReason::None)?; return infer_call_expr_func( db, @@ -439,7 +443,7 @@ fn infer_generic_type_doc_function( let func = operator.get_operator_func(db); match func { LuaType::DocFunction(_) => { - let new_f = instantiate_type_generic(db, &func, &substitutor); + let new_f = instantiate_type_generic(db, &func, &mapper); if let LuaType::DocFunction(f) = new_f { overloads.push(f.clone()); } @@ -454,7 +458,7 @@ fn infer_generic_type_doc_function( } let typ = LuaType::DocFunction(signature.to_call_operator_func_type()); - let new_f = instantiate_type_generic(db, &typ, &substitutor); + let new_f = instantiate_type_generic(db, &typ, &mapper); if let LuaType::DocFunction(f) = new_f { overloads.push(f.clone()); } @@ -900,7 +904,31 @@ mod tests { ); assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); - assert_eq!(ws.expr_ty("payload"), ws.ty("string")); + assert_eq!(ws.expr_ty("payload"), LuaType::Unknown); + } + + #[test] + fn test_top_level_generic_literal_keeps_function_param_and_return_consistent() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + local function id(value) end + + id("hello") + "#, + ); + let call_expr = ws.get_node::(file_id); + let semantic_model = ws.analysis.compilation.get_semantic_model(file_id).unwrap(); + let func = semantic_model + .infer_call_expr_func(call_expr, None) + .unwrap(); + + let param_ty = func.get_params()[0].1.clone().unwrap(); + assert_eq!(ws.humanize_type(param_ty), "\"hello\""); + assert_eq!(ws.humanize_type(func.get_ret().clone()), "\"hello\""); } #[test] diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs index 1f3ad09ce..5e49a7bee 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs @@ -18,7 +18,7 @@ use crate::{ enum_variable_is_param, get_keyof_members, get_tpl_ref_extend_type, semantic::{ InferGuard, - generic::{TypeSubstitutor, instantiate_type_generic}, + generic::{TypeMapper, instantiate_type_generic}, infer::{ VarRefId, infer_index::infer_array::{ @@ -663,7 +663,7 @@ fn infer_generic_members_from_super_generics( db: &DbIndex, cache: &mut LuaInferCache, type_decl_id: &LuaTypeDeclId, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> Option { @@ -677,7 +677,7 @@ fn infer_generic_members_from_super_generics( let type_decl_id = type_decl.get_id(); if let Some(super_types) = type_index.get_super_types(&type_decl_id) { super_types.iter().find_map(|super_type| { - let super_type = instantiate_type_generic(db, super_type, substitutor); + let super_type = instantiate_type_generic(db, super_type, mapper); infer_member_by_lookup(db, cache, &super_type, lookup, &infer_guard.fork()).ok() }) } else { @@ -694,14 +694,22 @@ fn infer_generic_member( ) -> InferResult { let base_type = generic_type.get_base_type(); - let generic_params = generic_type.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); - if let LuaType::Ref(base_type_decl_id) = &base_type { let type_index = db.get_type_index(); - if let Some(type_decl) = type_index.get_type_decl(base_type_decl_id) - && type_decl.is_alias() - && let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) + let Some(type_decl) = type_index.get_type_decl(base_type_decl_id) else { + return Err(InferFailReason::None); + }; + let generic_params = generic_type.get_params(); + let mapper = if type_decl.is_alias() { + TypeMapper::from_alias(db, generic_params.clone(), base_type_decl_id) + } else { + TypeMapper::from_type_array(generic_params.clone()) + }; + + if type_decl.is_alias() + && let Some(origin_type) = type_decl + .get_alias_ref() + .map(|origin| instantiate_type_generic(db, origin, &mapper)) { return infer_member_by_lookup(db, cache, &origin_type, lookup, &infer_guard.fork()); } @@ -710,18 +718,19 @@ fn infer_generic_member( db, cache, base_type_decl_id, - &substitutor, + &mapper, lookup, infer_guard, ); if let Some(result) = result { return Ok(result); } - } - let member_type = infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard)?; + let member_type = infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard)?; + return Ok(instantiate_type_generic(db, &member_type, &mapper)); + } - Ok(instantiate_type_generic(db, &member_type, &substitutor)) + infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard) } fn infer_instance_member( @@ -986,17 +995,24 @@ fn infer_member_by_index_generic( return Err(InferFailReason::None); }; let generic_params = generic.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); let type_index = db.get_type_index(); let type_decl = type_index .get_type_decl(&type_decl_id) .ok_or(InferFailReason::None)?; + let mapper = if type_decl.is_alias() { + TypeMapper::from_alias(db, generic_params.clone(), &type_decl_id) + } else { + TypeMapper::from_type_array(generic_params.clone()) + }; if type_decl.is_alias() { - if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { + if let Some(origin_type) = type_decl + .get_alias_ref() + .map(|origin| instantiate_type_generic(db, origin, &mapper)) + { return infer_member_by_operator_key_type( db, cache, - &instantiate_type_generic(db, &origin_type, &substitutor), + &origin_type, key_type, &infer_guard.fork(), ); @@ -1013,9 +1029,9 @@ fn infer_member_by_index_generic( .get_operator(index_operator_id) .ok_or(InferFailReason::None)?; let operand = index_operator.get_operand(db); - let instianted_operand = instantiate_type_generic(db, &operand, &substitutor); + let instianted_operand = instantiate_type_generic(db, &operand, &mapper); let return_type = - instantiate_type_generic(db, &index_operator.get_result(db)?, &substitutor); + instantiate_type_generic(db, &index_operator.get_result(db)?, &mapper); let result = infer_index_metamethod_by_key_type(db, key_type, &instianted_operand, &return_type); @@ -1038,7 +1054,7 @@ fn infer_member_by_index_generic( let result = infer_member_by_operator_key_type( db, cache, - &instantiate_type_generic(db, &super_type, &substitutor), + &instantiate_type_generic(db, &super_type, &mapper), key_type, &infer_guard.fork(), ); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs index 357b9ae8f..9e60638fa 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs @@ -351,18 +351,16 @@ pub fn infer_global_type(db: &DbIndex, name: &str) -> InferResult { .ok_or(InferFailReason::None)?; if decl_ids.len() == 1 { let id = decl_ids[0]; - let typ = match db.get_type_index().get_type_cache(&id.into()) { - Some(type_cache) => type_cache.as_type().clone(), - None => return Err(InferFailReason::UnResolveDeclType(id)), - }; - // todo: 不置为 Unknown 有可能引用泛型函数中的泛型参数导致泄露, 但这样会导致丢失类型, 我们可能需要更好的办法去处理 - return if !typ.is_generic() && typ.contain_tpl() { - // This decl is located in a generic function, - // and is type contains references to generic variables - // of this function. - Ok(LuaType::Unknown) - } else { - Ok(typ) + return match db.get_type_index().get_type_cache(&id.into()) { + Some(type_cache) => { + let typ = type_cache.as_type(); + if is_bare_leaked_tpl(typ) { + Ok(LuaType::Unknown) + } else { + Ok(typ.clone()) + } + } + None => Err(InferFailReason::UnResolveDeclType(id)), }; } @@ -382,10 +380,7 @@ pub fn infer_global_type(db: &DbIndex, name: &str) -> InferResult { Some(type_cache) => { let typ = type_cache.as_type(); - if typ.contain_tpl() { - // This decl is located in a generic function, - // and is type contains references to generic variables - // of this function. + if is_bare_leaked_tpl(typ) { continue; } @@ -461,3 +456,10 @@ pub fn find_self_decl_or_member_id( _ => None, } } + +fn is_bare_leaked_tpl(typ: &LuaType) -> bool { + matches!( + typ, + LuaType::TplRef(_) | LuaType::ConstTplRef(_) | LuaType::StrTplRef(_) | LuaType::SelfInfer + ) +} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index e915915b1..7b9d10f3a 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -16,7 +16,7 @@ use crate::{ var_ref_id::get_var_expr_var_ref_id, }, }, - semantic::instantiate_func_generic, + semantic::infer_call_func_generic, }; pub fn get_type_at_call_expr( @@ -226,7 +226,7 @@ fn get_type_guard_call_info( let mut return_type = func_type.get_ret().clone(); if return_type.contain_tpl() { let Ok(inst_func) = cache.with_no_flow(|cache| { - instantiate_func_generic(db, cache, func_type.as_ref(), call_expr) + infer_call_func_generic(db, cache, func_type.as_ref(), call_expr) }) else { return Ok(None); }; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index 78c464852..61728c1e1 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -7,7 +7,7 @@ use crate::{ LuaSignature, LuaType, TypeOps, semantic::{ infer::{InferResult, VarRefId, narrow::narrow_down_type, try_infer_expr_no_flow}, - instantiate_func_generic, + infer_call_func_generic, }, }; @@ -579,7 +579,7 @@ fn instantiate_return_rows( return_type.clone(), ); match cache - .with_no_flow(|cache| instantiate_func_generic(db, cache, &func, call_expr.clone())) + .with_no_flow(|cache| infer_call_func_generic(db, cache, &func, call_expr.clone())) { Ok(instantiated) => instantiated.get_ret().clone(), Err(_) => return_type, diff --git a/crates/emmylua_code_analysis/src/semantic/member/find_index.rs b/crates/emmylua_code_analysis/src/semantic/member/find_index.rs index 20f8d4422..091031564 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/find_index.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/find_index.rs @@ -3,11 +3,8 @@ use hashbrown::{HashMap, HashSet}; use crate::{ DbIndex, InFiled, InferGuardRef, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, - LuaType, LuaTypeDeclId, LuaUnionType, TypeOps, - semantic::{ - InferGuard, - generic::{TypeSubstitutor, instantiate_type_generic}, - }, + LuaType, LuaTypeDeclId, LuaUnionType, TypeMapper, TypeOps, + semantic::{InferGuard, generic::instantiate_type_generic}, }; use super::{FindMembersResult, LuaMemberInfo, intersect_member_types}; @@ -310,13 +307,13 @@ fn find_index_generic( }; let generic_params = generic.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); + let mapper = TypeMapper::from_type_array(generic_params.clone()); let type_index = db.get_type_index(); let type_decl = type_index.get_type_decl(&type_decl_id)?; if type_decl.is_alias() { - if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { - let instantiated_type = instantiate_type_generic(db, &origin_type, &substitutor); + if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&mapper)) { + let instantiated_type = instantiate_type_generic(db, &origin_type, &mapper); return find_index_operations_guard(db, &instantiated_type, infer_guard); } return None; @@ -332,11 +329,11 @@ fn find_index_generic( for index_operator_id in index_operator_ids { if let Some(index_operator) = operator_index.get_operator(index_operator_id) { let operand = index_operator.get_operand(db); - let instantiated_operand = instantiate_type_generic(db, &operand, &substitutor); + let instantiated_operand = instantiate_type_generic(db, &operand, &mapper); if let Ok(return_type) = index_operator.get_result(db) { let instantiated_return_type = - instantiate_type_generic(db, &return_type, &substitutor); + instantiate_type_generic(db, &return_type, &mapper); members.push(LuaMemberInfo { property_owner_id: None, @@ -353,7 +350,7 @@ fn find_index_generic( // Find index operations in super types if let Some(supers) = type_index.get_super_types(&type_decl_id) { for super_type in supers { - let instantiated_super = instantiate_type_generic(db, &super_type, &substitutor); + let instantiated_super = instantiate_type_generic(db, &super_type, &mapper); if let Some(super_members) = find_index_operations_guard(db, &instantiated_super, infer_guard) { diff --git a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs index 606ba27b5..a166687bf 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs @@ -8,7 +8,7 @@ use crate::{ LuaTypeDeclId, LuaUnionType, semantic::{ InferGuard, - generic::{TypeSubstitutor, instantiate_type_generic}, + generic::{TypeMapper, instantiate_type_generic}, }, }; @@ -84,7 +84,7 @@ pub fn find_members_with_key_in_scope( struct FindMembersContext { file_id: FileId, infer_guard: InferGuardRef, - substitutor: Option, + mapper: Option, } impl FindMembersContext { @@ -92,14 +92,14 @@ impl FindMembersContext { Self { file_id, infer_guard, - substitutor: None, + mapper: None, } } - fn with_substitutor(&self, substitutor: TypeSubstitutor) -> Self { + fn with_mapper(&self, mapper: TypeMapper) -> Self { Self { file_id: self.file_id, infer_guard: self.infer_guard.clone(), - substitutor: Some(substitutor), + mapper: Some(mapper), } } @@ -107,13 +107,13 @@ impl FindMembersContext { Self { file_id: self.file_id, infer_guard: self.infer_guard.fork(), - substitutor: self.substitutor.clone(), + mapper: self.mapper.clone(), } } fn instantiate_type(&self, db: &DbIndex, ty: &LuaType) -> LuaType { - if let Some(substitutor) = &self.substitutor { - instantiate_type_generic(db, ty, substitutor) + if let Some(mapper) = &self.mapper { + instantiate_type_generic(db, ty, mapper) } else { ty.clone() } @@ -495,17 +495,24 @@ fn find_generic_members( .iter() .map(|param| ctx.instantiate_type(db, param)) .collect(); - let substitutor = TypeSubstitutor::from_type_array(instantiated_params); let type_decl = db.get_type_index().get_type_decl(&base_ref_id)?; - let ctx_with_substitutor = ctx.with_substitutor(substitutor.clone()); - if let Some(origin) = type_decl.get_alias_origin(db, Some(&substitutor)) { - return find_members_guard(db, &origin, &ctx_with_substitutor, filter); + let mapper = if type_decl.is_alias() { + TypeMapper::from_alias(db, instantiated_params, &base_ref_id) + } else { + TypeMapper::from_type_array(instantiated_params) + }; + let ctx_with_mapper = ctx.with_mapper(mapper.clone()); + if let Some(origin) = type_decl + .get_alias_ref() + .map(|origin| instantiate_type_generic(db, origin, &mapper)) + { + return find_members_guard(db, &origin, &ctx_with_mapper, filter); } find_members_guard( db, &LuaType::Ref(base_ref_id.clone()), - &ctx_with_substitutor, + &ctx_with_mapper, filter, ) } diff --git a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs index cadd3988e..e322cd9e3 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs @@ -4,9 +4,8 @@ use smol_str::SmolStr; use crate::{ DbIndex, InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaMemberKey, - LuaMemberOwner, LuaObjectType, LuaTupleType, LuaType, LuaTypeDeclId, TypeOps, - check_type_compact, - semantic::generic::{TypeSubstitutor, instantiate_type_generic}, + LuaMemberOwner, LuaObjectType, LuaTupleType, LuaType, LuaTypeDeclId, TypeMapper, TypeOps, + check_type_compact, semantic::generic::instantiate_type_generic, }; use super::{RawGetMemberTypeResult, get_buildin_type_map_type_id}; @@ -214,17 +213,21 @@ fn infer_generic_raw_member_type( ) -> RawGetMemberTypeResult { let base_ref_id = generic_type.get_base_type_id_ref(); let generic_params = generic_type.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); let type_decl = db .get_type_index() .get_type_decl(&base_ref_id) .ok_or(InferFailReason::None)?; + let mapper = if type_decl.is_alias() { + TypeMapper::from_alias(db, generic_params.clone(), &base_ref_id) + } else { + TypeMapper::from_type_array(generic_params.clone()) + }; - if let Some(origin) = type_decl.get_alias_origin(db, Some(&substitutor)) { - return infer_raw_member_type(db, &origin, member_key); + if let Some(origin) = type_decl.get_alias_origin(db, Some(&mapper)) { + return infer_raw_member_type_guard(db, &origin, member_key, infer_guard); } let base_ref_type = LuaType::Ref(base_ref_id.clone()); let result = infer_raw_member_type_guard(db, &base_ref_type, member_key, infer_guard)?; - Ok(instantiate_type_generic(db, &result, &substitutor)) + Ok(instantiate_type_generic(db, &result, &mapper)) } diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs new file mode 100644 index 000000000..19d139395 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs @@ -0,0 +1,83 @@ +use std::sync::Arc; + +use hashbrown::HashSet; + +use crate::db_index::{DbIndex, LuaFunctionType, LuaType, LuaTypeDeclId}; + +use super::super::{generic::TypeMapper, infer::InferFailReason}; + +pub(crate) fn collect_callable_overload_groups( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, +) -> Result<(), InferFailReason> { + let mut visiting_aliases = HashSet::new(); + collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) +} + +fn collect_callable_overload_groups_inner( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, + visiting_aliases: &mut HashSet, +) -> Result<(), InferFailReason> { + match callable_type { + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { + return Ok(()); + }; + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + + let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(type_id); + result?; + } + LuaType::Generic(generic) => { + let type_id = generic.get_base_type_id(); + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + let mapper = TypeMapper::from_type_array(generic.get_params().to_vec()); + let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { + visiting_aliases.remove(&type_id); + return Ok(()); + }; + + let result = if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&mapper)) { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(&type_id); + result?; + } + LuaType::Union(union) => { + for member in union.into_vec() { + collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; + } + } + LuaType::Intersection(intersection) => { + for member in intersection.get_types() { + collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; + } + } + LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), + LuaType::Signature(sig_id) => { + let Some(signature) = db.get_signature_index().get(sig_id) else { + return Ok(()); + }; + let mut overloads = signature.overloads.to_vec(); + overloads.push(signature.to_doc_func_type()); + groups.push(overloads); + } + _ => {} + } + + Ok(()) +} diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index a6447a91c..203d97870 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -1,3 +1,4 @@ +mod collect_callable_overloads; mod resolve_signature_by_args; use std::sync::Arc; @@ -8,10 +9,11 @@ use crate::db_index::{DbIndex, LuaFunctionType, LuaType}; use super::{ LuaInferCache, - generic::instantiate_func_generic, + generic::infer_call_func_generic, infer::{InferCallFuncResult, InferFailReason, infer_expr_list_types, try_infer_expr_no_flow}, }; +pub(crate) use collect_callable_overloads::collect_callable_overload_groups; pub(crate) use resolve_signature_by_args::{callable_accepts_args, resolve_signature_by_args}; pub fn resolve_signature( @@ -78,7 +80,7 @@ fn resolve_signature_by_generic( ) -> InferCallFuncResult { let mut instantiate_funcs = Vec::new(); for func in overloads { - let instantiate_func = instantiate_func_generic(db, cache, &func, call_expr.clone())?; + let instantiate_func = infer_call_func_generic(db, cache, &func, call_expr.clone())?; instantiate_funcs.push(Arc::new(instantiate_func)); } resolve_signature_by_args( diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs index e9827448b..4f7421789 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs @@ -13,8 +13,7 @@ use table_generic_check::check_table_generic_type_compact; use tuple_type_check::check_tuple_type_compact; use crate::{ - LuaType, LuaUnionType, TypeSubstitutor, - semantic::type_check::type_check_context::TypeCheckContext, + LuaType, LuaUnionType, TypeMapper, semantic::type_check::type_check_context::TypeCheckContext, }; use super::{ @@ -37,9 +36,9 @@ pub fn check_complex_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { + let mapper = + TypeMapper::from_alias(context.db, generic.get_params().clone(), &base_id); + if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&mapper)) { return check_general_type_compact( context, source, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs index afee3eddc..796677345 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs @@ -1,5 +1,5 @@ use crate::{ - TypeSubstitutor, + TypeMapper, db_index::{LuaFunctionType, LuaOperatorMetaMethod, LuaSignatureId, LuaType, LuaTypeDeclId}, semantic::type_check::type_check_context::TypeCheckContext, }; @@ -23,9 +23,9 @@ pub fn check_doc_func_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { + let mapper = + TypeMapper::from_alias(context.db, generic.get_params().clone(), &base_id); + if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&mapper)) { return check_general_type_compact( context, &LuaType::DocFunction(source_func.clone().into()), diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs index 0929c7ed5..e12e0c949 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs @@ -1,8 +1,8 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ - LuaGenericType, LuaMemberOwner, LuaType, LuaTypeCache, LuaTypeDeclId, RenderLevel, - TypeSubstitutor, complete_type_generic_args_in_type, humanize_type, instantiate_type_generic, + LuaGenericType, LuaMemberOwner, LuaType, LuaTypeCache, LuaTypeDeclId, RenderLevel, TypeMapper, + complete_type_generic_args_in_type, humanize_type, instantiate_type_generic, semantic::{member::find_members, type_check::type_check_context::TypeCheckContext}, }; @@ -24,9 +24,10 @@ pub fn check_generic_type_compact( .get_type_decl(&source_generic.get_base_type_id()) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(source_generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { + let mapper = + TypeMapper::from_alias(context.db, source_generic.get_params().clone(), &base_id); + if let Some(alias_ref) = decl.get_alias_ref() { + let alias_origin = instantiate_type_generic(context.db, alias_ref, &mapper); return check_general_type_compact( context, &alias_origin, @@ -61,10 +62,9 @@ pub fn check_generic_type_compact( { for mut super_type in supers { if super_type.contain_tpl() { - let substitutor = - TypeSubstitutor::from_type_array(compact_generic.get_params().clone()); - super_type = - instantiate_type_generic(context.db, &super_type, &substitutor); + let mapper = + TypeMapper::from_type_array(compact_generic.get_params().clone()); + super_type = instantiate_type_generic(context.db, &super_type, &mapper); } let result = check_generic_type_compact( diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index 6f994cb0c..eb15bfbbe 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -1,7 +1,7 @@ use std::ops::Deref; use crate::{ - DbIndex, LuaType, LuaTypeDeclId, TypeSubstitutor, VariadicType, + DbIndex, LuaType, LuaTypeDeclId, TypeMapper, VariadicType, semantic::type_check::{ is_sub_type_of, type_check_context::{TypeCheckCheckLevel, TypeCheckContext}, @@ -310,11 +310,9 @@ pub fn check_simple_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = - decl.get_alias_origin(context.db, Some(&substitutor)) - { + let mapper = + TypeMapper::from_alias(context.db, generic.get_params().clone(), &base_id); + if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&mapper)) { return check_general_type_compact( context, source, diff --git a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs index 0e9050f4b..a9c77a058 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs @@ -2,7 +2,7 @@ use emmylua_code_analysis::{ DbIndex, GenericTplId, InferGuard, InferGuardRef, LuaAliasCallKind, LuaAliasCallType, LuaDeclLocation, LuaFunctionType, LuaMember, LuaMemberKey, LuaMemberOwner, LuaMultiLineUnion, LuaSemanticDeclId, LuaStringTplType, LuaType, LuaTypeCache, LuaTypeDeclId, LuaUnionType, - RenderLevel, SemanticDeclLevel, TypeSubstitutor, build_call_constraint_context, get_real_type, + RenderLevel, SemanticDeclLevel, TypeMapper, build_call_constraint_context, get_real_type, instantiate_type_generic, normalize_constraint_type, }; use emmylua_parser::{ @@ -340,16 +340,16 @@ fn infer_call_arg_list( param_idx += 1; } } - let constraint_substitutor = build_call_constraint_context(&builder.semantic_model, &call_expr) - .map(|ctx| ctx.substitutor); - let substitutor = constraint_substitutor.as_ref(); + let constraint_mapper = + build_call_constraint_context(&builder.semantic_model, &call_expr).map(|ctx| ctx.mapper); + let mapper = constraint_mapper.as_ref(); let typ = call_expr_func .get_params() .get(param_idx)? .1 .clone() .unwrap_or(LuaType::Unknown); - let typ = resolve_param_type(builder, typ, substitutor); + let typ = resolve_param_type(builder, typ, mapper); let mut types = Vec::new(); types.push(typ); push_function_overloads_param( @@ -357,7 +357,7 @@ fn infer_call_arg_list( &call_expr, call_expr_func.get_params(), param_idx, - substitutor, + mapper, &mut types, ); Some(types.into_iter().unique().collect()) // 需要去重 @@ -366,22 +366,22 @@ fn infer_call_arg_list( fn resolve_param_type( builder: &CompletionBuilder, mut typ: LuaType, - substitutor: Option<&TypeSubstitutor>, + mapper: Option<&TypeMapper>, ) -> LuaType { let db = builder.semantic_model.get_db(); - if let Some(substitutor) = substitutor { - typ = apply_substitutor_to_type(db, typ, substitutor); + if let Some(mapper) = mapper { + typ = apply_mapper_to_type(db, typ, mapper); } normalize_constraint_type(db, typ) } -fn apply_substitutor_to_type(db: &DbIndex, typ: LuaType, substitutor: &TypeSubstitutor) -> LuaType { +fn apply_mapper_to_type(db: &DbIndex, typ: LuaType, mapper: &TypeMapper) -> LuaType { if let LuaType::Call(alias_call) = &typ { if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf { let operands = alias_call .get_operands() .iter() - .map(|operand| instantiate_type_generic(db, operand, substitutor)) + .map(|operand| instantiate_type_generic(db, operand, mapper)) .collect::>(); return LuaType::Call(Arc::new(LuaAliasCallType::new( alias_call.get_call_kind(), @@ -389,16 +389,16 @@ fn apply_substitutor_to_type(db: &DbIndex, typ: LuaType, substitutor: &TypeSubst ))); } } - if let Some(alias_call) = rebuild_keyof_alias_call(db, &typ, substitutor) { + if let Some(alias_call) = rebuild_keyof_alias_call(db, &typ, mapper) { return alias_call; } - instantiate_type_generic(db, &typ, substitutor) + instantiate_type_generic(db, &typ, mapper) } fn rebuild_keyof_alias_call( db: &DbIndex, original_type: &LuaType, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, ) -> Option { let tpl = match original_type { LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl, @@ -415,7 +415,7 @@ fn rebuild_keyof_alias_call( let operands = alias_call .get_operands() .iter() - .map(|operand| instantiate_type_generic(db, operand, substitutor)) + .map(|operand| instantiate_type_generic(db, operand, mapper)) .collect::>(); Some(LuaType::Call(Arc::new(LuaAliasCallType::new( alias_call.get_call_kind(), @@ -428,7 +428,7 @@ fn push_function_overloads_param( call_expr: &LuaCallExpr, call_params: &[(String, Option)], param_idx: usize, - substitutor: Option<&TypeSubstitutor>, + mapper: Option<&TypeMapper>, types: &mut Vec, ) -> Option<()> { let member_index = builder.semantic_model.get_db().get_member_index(); @@ -508,7 +508,7 @@ fn push_function_overloads_param( // 添加匹配的参数类型 if let Some(param_type) = overload_params.get(param_idx).and_then(|p| p.1.clone()) { - let param_type = resolve_param_type(builder, param_type, substitutor); + let param_type = resolve_param_type(builder, param_type, mapper); types.push(param_type); } } diff --git a/crates/emmylua_ls/src/handlers/definition/goto_function.rs b/crates/emmylua_ls/src/handlers/definition/goto_function.rs index f60d9b395..971c74aae 100644 --- a/crates/emmylua_ls/src/handlers/definition/goto_function.rs +++ b/crates/emmylua_ls/src/handlers/definition/goto_function.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ LuaCompilation, LuaDeclId, LuaFunctionType, LuaSemanticDeclId, LuaSignature, LuaSignatureId, - LuaType, SemanticDeclLevel, SemanticModel, instantiate_func_generic, + LuaType, SemanticDeclLevel, SemanticModel, infer_call_func_generic, }; use emmylua_parser::{ LuaAstNode, LuaCallExpr, LuaExpr, LuaLiteralToken, LuaSyntaxToken, LuaTokenKind, @@ -291,7 +291,7 @@ pub fn compare_function_types( call_expr: &LuaCallExpr, ) -> Option { if func.contain_tpl() { - let instantiated_func = instantiate_func_generic( + let instantiated_func = infer_call_func_generic( semantic_model.get_db(), &mut semantic_model.get_cache().borrow_mut(), func, diff --git a/crates/emmylua_ls/src/handlers/hover/function/mod.rs b/crates/emmylua_ls/src/handlers/hover/function/mod.rs index 2402e078e..7e046bb05 100644 --- a/crates/emmylua_ls/src/handlers/hover/function/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/function/mod.rs @@ -2,9 +2,9 @@ use std::{collections::HashSet, sync::Arc, vec}; use emmylua_code_analysis::{ AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, - LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, - TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, instantiate_doc_function, - instantiate_func_generic, try_extract_signature_id_from_field, + LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, TypeMapper, + VariadicType, humanize_type, infer_call_expr_func, infer_call_func_generic, + instantiate_type_generic, try_extract_signature_id_from_field, }; use crate::handlers::hover::{ @@ -104,7 +104,7 @@ fn build_function_call_hover( signature.get_type_params(), signature.get_return_type(), ); - let instantiated_signature = instantiate_func_generic( + let instantiated_signature = infer_call_func_generic( db, &mut builder.semantic_model.get_cache().borrow_mut(), &base_function, @@ -188,8 +188,8 @@ fn build_function_define_hover( _ => None, }; - if let Some(substitutor) = &builder.substitutor { - if let Some(lua_func) = hover_instantiate_function_type(db, &typ, substitutor) { + if let Some(mapper) = &builder.mapper { + if let Some(lua_func) = hover_instantiate_function_type(db, &typ, mapper) { typ = LuaType::DocFunction(lua_func); } } @@ -486,7 +486,7 @@ fn instantiate_call_return_overloads( row_return_type, ); let instantiated_row = - instantiate_func_generic(db, &mut cache, &row_function, call_expr.clone()) + infer_call_func_generic(db, &mut cache, &row_function, call_expr.clone()) .ok() .map(|func| match func.get_ret() { LuaType::Variadic(variadic) => match variadic.as_ref() { @@ -696,14 +696,14 @@ fn function_member_is_field(db: &DbIndex, semantic_decls: &[(LuaSemanticDeclId, fn hover_instantiate_function_type( db: &DbIndex, typ: &LuaType, - substitutor: &TypeSubstitutor, + mapper: &TypeMapper, ) -> Option> { if !typ.contain_tpl() { return None; } match typ { - LuaType::DocFunction(f) => { - if let LuaType::DocFunction(f) = instantiate_doc_function(db, f, substitutor) { + LuaType::DocFunction(_) => { + if let LuaType::DocFunction(f) = instantiate_type_generic(db, typ, mapper) { Some(f) } else { None diff --git a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs index d4cd08773..642b6c2e6 100644 --- a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs +++ b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ - GenericTplId, LuaCompilation, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, - RenderLevel, SemanticModel, TypeSubstitutor, + LuaCompilation, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, RenderLevel, + SemanticModel, TypeMapper, }; use emmylua_parser::{ LuaAstNode, LuaCallExpr, LuaExpr, LuaLocalName, LuaLocalStat, LuaSyntaxKind, LuaSyntaxToken, @@ -34,8 +34,8 @@ pub struct HoverBuilder<'a> { pub detail_render_level: RenderLevel, pub is_completion: bool, - // 默认的泛型替换器 - pub substitutor: Option, + // 默认的泛型 mapper + pub mapper: Option, } impl<'a> HoverBuilder<'a> { @@ -52,8 +52,8 @@ impl<'a> HoverBuilder<'a> { RenderLevel::Detailed }; - let substitutor = if let Some(token) = token.clone() { - infer_substitutor_base_type(semantic_model, token) + let mapper = if let Some(token) = token.clone() { + infer_mapper_base_type(semantic_model, token) } else { None }; @@ -70,7 +70,7 @@ impl<'a> HoverBuilder<'a> { type_expansion: None, tag_content: None, detail_render_level, - substitutor, + mapper, } } @@ -293,11 +293,11 @@ impl<'a> HoverBuilder<'a> { } } -// 推断基础泛型替换器 -fn infer_substitutor_base_type( +// 推断基础泛型 mapper +fn infer_mapper_base_type( semantic_model: &SemanticModel, trigger_token: LuaSyntaxToken, -) -> Option { +) -> Option { let parent = trigger_token.parent()?; match parent.kind().into() { LuaSyntaxKind::LocalName => { @@ -312,7 +312,7 @@ fn infer_substitutor_base_type( for (index, name) in local_name_list.iter().enumerate() { if target_local_name == *name { let value_expr = value_expr_list.get(index)?; - return substitutor_form_expr(semantic_model, value_expr); + return mapper_from_expr(semantic_model, value_expr); } } } @@ -325,20 +325,13 @@ fn infer_substitutor_base_type( None } -pub fn substitutor_form_expr( - semantic_model: &SemanticModel, - expr: &LuaExpr, -) -> Option { +pub fn mapper_from_expr(semantic_model: &SemanticModel, expr: &LuaExpr) -> Option { if let LuaExpr::IndexExpr(index_expr) = expr { let prefix_type = semantic_model .infer_expr(index_expr.get_prefix_expr()?) .ok()?; - let mut substitutor = TypeSubstitutor::new(); if let LuaType::Generic(generic) = prefix_type { - for (i, param) in generic.get_params().iter().enumerate() { - substitutor.insert_type(GenericTplId::Type(i as u32), param.clone(), true); - } - return Some(substitutor); + return Some(TypeMapper::from_type_array(generic.get_params().clone())); } else { return None; } diff --git a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs index 4e5cef58f..b2bcfd1bb 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs @@ -1,7 +1,7 @@ use emmylua_code_analysis::{ DbIndex, InFiled, LuaCompilation, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, LuaSignatureId, LuaType, - LuaTypeDeclId, RenderLevel, SemanticModel, TypeSubstitutor, + LuaTypeDeclId, RenderLevel, SemanticModel, TypeMapper, }; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken, LuaTokenKind}; use lsp_types::{ @@ -566,8 +566,8 @@ fn build_generic_signature_help( let generic_params = generic.get_params(); let type_decl_id = generic.get_base_type_id_ref(); if let Some(type_decl) = db.get_type_index().get_type_decl(type_decl_id) { - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); - if let Some(LuaType::DocFunction(f)) = type_decl.get_alias_origin(db, Some(&substitutor)) { + let mapper = TypeMapper::from_type_array(generic_params.clone()); + if let Some(LuaType::DocFunction(f)) = type_decl.get_alias_origin(db, Some(&mapper)) { let semantic_id = LuaSemanticDeclId::TypeDecl(type_decl_id.clone()); let description = db .get_property_index() diff --git a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs index 8132c2f3e..c7f6c9066 100644 --- a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs +++ b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs @@ -28,19 +28,6 @@ mod tests { result } - fn make_issue_1028_repeated_prefix_guard_chain_content() -> String { - let mut content = String::from("V_cfad19afc42b = V_cfad19afc42b or {}\n"); - for i in 0..600 { - let table_key = 3_121_212; - let field_key = 1_111_112 + i; - content.push_str(&format!( - "if V_cfad19afc42b[{table_key}] and V_cfad19afc42b[{table_key}][{field_key}] then\n V_cfad19afc42b[{table_key}][{field_key}][\"__STR_{i}__\"] = \"__STR_{}__\"\nend\n\n", - i + 1, - )); - } - content - } - #[gtest] fn test_1() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -148,7 +135,22 @@ m.foo() Ok(()) } + #[cfg(feature = "full-test")] + fn make_issue_1028_repeated_prefix_guard_chain_content() -> String { + let mut content = String::from("V_cfad19afc42b = V_cfad19afc42b or {}\n"); + for i in 0..600 { + let table_key = 3_121_212; + let field_key = 1_111_112 + i; + content.push_str(&format!( + "if V_cfad19afc42b[{table_key}] and V_cfad19afc42b[{table_key}][{field_key}] then\n V_cfad19afc42b[{table_key}][{field_key}][\"__STR_{i}__\"] = \"__STR_{}__\"\nend\n\n", + i + 1, + )); + } + content + } + #[gtest] + #[cfg(feature = "full-test")] fn test_issue_1028_i18n_semantic_tokens_repeated_prefix_guard_chain() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); let content = make_issue_1028_repeated_prefix_guard_chain_content();