diff --git a/mypy/applytype.py b/mypy/applytype.py index c8003795ba0b1..04d87686a7962 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -171,11 +171,17 @@ def apply_generic_arguments( assert isinstance(typ, TypeVarLikeType) remaining_tvars.append(typ) + instance_type = None + if callable.instance_type is not None: + instance_type = expand_type(callable.instance_type, id_to_type) + assert isinstance(instance_type, ProperType) + return callable.copy_modified( ret_type=expand_type(callable.ret_type, id_to_type), variables=remaining_tvars, type_guard=type_guard, type_is=type_is, + instance_type=instance_type, ) diff --git a/mypy/cache.py b/mypy/cache.py index b9cd8ad7a9050..e90c933fdab9a 100644 --- a/mypy/cache.py +++ b/mypy/cache.py @@ -69,7 +69,7 @@ from mypy_extensions import u8 # High-level cache layout format -CACHE_VERSION: Final = 8 +CACHE_VERSION: Final = 9 # Type used internally to represent errors: # (path, line, column, end_line, end_column, severity, message, code) @@ -558,6 +558,20 @@ def write_json(data: WriteBuffer, value: dict[str, Any]) -> None: write_json_value(data, value[key]) +def write_flags(data: WriteBuffer, flags: list[bool]) -> None: + assert len(flags) <= 26, "This many flags not supported yet" + packed = 0 + for i, flag in enumerate(flags): + if flag: + packed |= 1 << i + write_int(data, packed) + + +def read_flags(data: ReadBuffer, num_flags: int) -> list[bool]: + packed = read_int(data) + return [(packed & (1 << i)) != 0 for i in range(num_flags)] + + def write_errors(data: WriteBuffer, errs: list[ErrorTuple]) -> None: write_tag(data, LIST_GEN) write_int_bare(data, len(errs)) diff --git a/mypy/checker.py b/mypy/checker.py index 58b7fedf55f20..cb443d72983b4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5515,7 +5515,7 @@ def check_except_handler_test(self, n: Expression, is_star: bool) -> Type: if not item.is_type_obj(): self.fail(message_registry.INVALID_EXCEPTION_TYPE, n) return self.default_exception_type(is_star) - exc_type = erase_typevars(item.ret_type) + exc_type = erase_typevars(item.get_instance_type()) elif isinstance(ttype, TypeType): exc_type = ttype.item else: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 123c5f821ed29..a3f389c54546f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -689,7 +689,7 @@ def method_fullname(self, object_type: Type, method_name: str) -> str | None: # For class method calls, object_type is a callable representing the class object. # We "unwrap" it to a regular type, as the class/instance method difference doesn't # affect the fully qualified name. - object_type = get_proper_type(object_type.ret_type) + object_type = object_type.get_instance_type() elif isinstance(object_type, TypeType): object_type = object_type.item @@ -717,9 +717,9 @@ def always_returns_none(self, node: Expression) -> bool: if isinstance(typ, Instance): info = typ.type elif isinstance(typ, CallableType) and typ.is_type_obj(): - ret_type = get_proper_type(typ.ret_type) - if isinstance(ret_type, Instance): - info = ret_type.type + instance_type = typ.get_instance_type(force_fallback=True) + if isinstance(instance_type, Instance): + info = instance_type.type else: return False else: @@ -1667,9 +1667,10 @@ def check_callable_call( callee = callee.with_unpacked_kwargs().with_normalized_var_args() if callable_name is None and callee.name: callable_name = callee.name - ret_type = get_proper_type(callee.ret_type) - if callee.is_type_obj() and isinstance(ret_type, Instance): - callable_name = ret_type.type.fullname + if callee.is_type_obj(): + instance_type = callee.get_instance_type(force_fallback=True) + if isinstance(instance_type, Instance): + callable_name = instance_type.type.fullname if isinstance(callable_node, RefExpr) and callable_node.fullname in ENUM_BASES: # An Enum() call that failed SemanticAnalyzerPass2.check_enum_call(). return callee.ret_type, callee diff --git a/mypy/checkmember.py b/mypy/checkmember.py index b5dcf94a0b206..e75a8ed7a5b03 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -407,15 +407,8 @@ def validate_super_call(node: FuncBase, mx: MemberContext) -> None: def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: MemberContext) -> Type: # Class attribute. # TODO super? - ret_type = typ.items[0].ret_type - assert isinstance(ret_type, ProperType) - if isinstance(ret_type, TupleType): - ret_type = tuple_fallback(ret_type) - if isinstance(ret_type, TypedDictType): - ret_type = ret_type.fallback - if isinstance(ret_type, LiteralType): - ret_type = ret_type.fallback - if isinstance(ret_type, Instance): + instance_type = typ.items[0].get_instance_type(force_fallback=True) + if isinstance(instance_type, Instance): if not mx.is_operator: # When Python sees an operator (eg `3 == 4`), it automatically translates that # into something like `int.__eq__(3, 4)` instead of `(3).__eq__(4)` as an @@ -432,14 +425,18 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member # See https://github.com/python/mypy/pull/1787 for more info. # TODO: do not rely on same type variables being present in all constructor overloads. result = analyze_class_attribute_access( - ret_type, name, mx, original_vars=typ.items[0].variables, mcs_fallback=typ.fallback + instance_type, + name, + mx, + original_vars=typ.items[0].variables, + mcs_fallback=typ.fallback, ) if result: return result # Look up from the 'type' type. return _analyze_member_access(name, typ.fallback, mx) else: - assert False, f"Unexpected type {ret_type!r}" + assert False, f"Unexpected type {instance_type!r}" def analyze_type_type_member_access( @@ -721,7 +718,7 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type: dunder_get_type = expand_type_by_instance(bound_method, typ) if isinstance(instance_type, FunctionLike) and instance_type.is_type_obj(): - owner_type = instance_type.items[0].ret_type + owner_type = instance_type.items[0].get_instance_type() instance_type = NoneType() elif isinstance(instance_type, TypeType): owner_type = instance_type.item diff --git a/mypy/constraints.py b/mypy/constraints.py index df79fdae5456c..f00a0175966f6 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -275,7 +275,11 @@ def infer_constraints_for_callable( def infer_constraints( - template: Type, actual: Type, direction: int, skip_neg_op: bool = False + template: Type, + actual: Type, + direction: int, + skip_neg_op: bool = False, + erase_types: bool = True, ) -> list[Constraint]: """Infer type constraints. @@ -312,14 +316,14 @@ def infer_constraints( # Return early on an empty branch. return [] type_state.inferring.append((template, actual)) - res = _infer_constraints(template, actual, direction, skip_neg_op) + res = _infer_constraints(template, actual, direction, skip_neg_op, erase_types) type_state.inferring.pop() return res - return _infer_constraints(template, actual, direction, skip_neg_op) + return _infer_constraints(template, actual, direction, skip_neg_op, erase_types) def _infer_constraints( - template: Type, actual: Type, direction: int, skip_neg_op: bool + template: Type, actual: Type, direction: int, skip_neg_op: bool, erase_types: bool ) -> list[Constraint]: orig_template = template template = get_proper_type(template) @@ -424,7 +428,7 @@ def _infer_constraints( return [] # Remaining cases are handled by ConstraintBuilderVisitor. - return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op)) + return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op, erase_types)) def _is_type_type(tp: ProperType) -> TypeGuard[TypeType | UnionType]: @@ -659,7 +663,9 @@ class ConstraintBuilderVisitor(TypeVisitor[list[Constraint]]): # TODO: The value may be None. Is that actually correct? actual: ProperType - def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> None: + def __init__( + self, actual: ProperType, direction: int, skip_neg_op: bool, erase_types: bool + ) -> None: # Direction must be SUBTYPE_OF or SUPERTYPE_OF. self.actual = actual self.direction = direction @@ -667,6 +673,10 @@ def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> Non # this is used to prevent infinite recursion when both template and actual are # generic callables. self.skip_neg_op = skip_neg_op + # Normally we should erase generic actual type when inferring against type[T] + # to avoid leaking type variables, see testGenericClassAsArgumentToType. + # The only exception is self-types in generic classes, where we set this to False. + self.erase_types = erase_types # Trivial leaf types @@ -759,13 +769,11 @@ def visit_instance(self, template: Instance) -> list[Constraint]: and template.type.is_protocol and self.direction == SUPERTYPE_OF ): - ret_type = get_proper_type(actual.ret_type) - if isinstance(ret_type, TupleType): - ret_type = mypy.typeops.tuple_fallback(ret_type) - if isinstance(ret_type, Instance): + instance_type = actual.get_instance_type(force_fallback=True) + if isinstance(instance_type, Instance): res.extend( self.infer_constraints_from_protocol_members( - ret_type, template, ret_type, template, class_obj=True + instance_type, template, instance_type, template, class_obj=True ) ) actual = actual.fallback @@ -1377,8 +1385,18 @@ def visit_overloaded(self, template: Overloaded) -> list[Constraint]: def visit_type_type(self, template: TypeType) -> list[Constraint]: if isinstance(self.actual, CallableType): + if self.actual.is_type_obj(): + instance_type = self.actual.get_instance_type() + if self.erase_types: + instance_type = erase_typevars(instance_type) + return infer_constraints(template.item, instance_type, self.direction) return infer_constraints(template.item, self.actual.ret_type, self.direction) elif isinstance(self.actual, Overloaded): + if self.actual.is_type_obj(): + instance_type = self.actual.items[0].get_instance_type() + if self.erase_types: + instance_type = erase_typevars(instance_type) + return infer_constraints(template.item, instance_type, self.direction) return infer_constraints(template.item, self.actual.items[0].ret_type, self.direction) elif isinstance(self.actual, TypeType): return infer_constraints(template.item, self.actual.item, self.direction) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 5790b717172ac..967206a0b4f4d 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -493,11 +493,16 @@ def visit_callable_type(self, t: CallableType) -> CallableType: arg_types = self.interpolate_args_for_unpack(t, var_arg.typ) else: arg_types = self.expand_types(t.arg_types) + instance_type = None + if t.instance_type is not None: + instance_type = t.instance_type.accept(self) + assert isinstance(instance_type, ProperType) expanded = t.copy_modified( arg_types=arg_types, ret_type=t.ret_type.accept(self), - type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), - type_is=(t.type_is.accept(self) if t.type_is is not None else None), + type_guard=t.type_guard.accept(self) if t.type_guard is not None else None, + type_is=t.type_is.accept(self) if t.type_is is not None else None, + instance_type=instance_type, ) if needs_normalization: return expanded.with_normalized_var_args() diff --git a/mypy/fixup.py b/mypy/fixup.py index c0782610e8f40..48ed7c26d57ba 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -279,6 +279,8 @@ def visit_callable_type(self, ct: CallableType) -> None: ct.type_guard.accept(self) if ct.type_is is not None: ct.type_is.accept(self) + if ct.instance_type is not None: + ct.instance_type.accept(self) def visit_overloaded(self, t: Overloaded) -> None: for ct in t.items: diff --git a/mypy/indirection.py b/mypy/indirection.py index c5f3fa89b8c4a..6bbda859de8f9 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -134,6 +134,12 @@ def visit_instance(self, t: types.Instance) -> None: def visit_callable_type(self, t: types.CallableType) -> None: self._visit_type_list(t.arg_types) self._visit(t.ret_type) + if t.type_guard is not None: + self._visit(t.type_guard) + if t.type_is is not None: + self._visit(t.type_is) + if t.instance_type is not None: + self._visit(t.instance_type) self._visit_type_tuple(t.variables) def visit_overloaded(self, t: types.Overloaded) -> None: diff --git a/mypy/infer.py b/mypy/infer.py index 56f4af753db82..2c155ee2456b3 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -70,8 +70,11 @@ def infer_type_arguments( actual: Type, is_supertype: bool = False, skip_unsatisfied: bool = False, + erase_types: bool = True, ) -> list[Type | None]: # Like infer_function_type_arguments, but only match a single type # against a generic type. - constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) + constraints = infer_constraints( + template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF, erase_types=erase_types + ) return solve_constraints(type_vars, constraints, skip_unsatisfied=skip_unsatisfied)[0] diff --git a/mypy/join.py b/mypy/join.py index a8c9910e60bb7..3b6c9cc23f6f3 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -773,10 +773,14 @@ def join_similar_callables(t: CallableType, s: CallableType) -> CallableType: fallback = t.fallback else: fallback = s.fallback + instance_type = None + if t.instance_type is not None and s.instance_type is not None: + instance_type = join_types(t.instance_type, s.instance_type) return t.copy_modified( arg_types=arg_types, arg_names=combine_arg_names(t, s), ret_type=join_types(t.ret_type, s.ret_type), + instance_type=instance_type, fallback=fallback, name=None, ) @@ -827,10 +831,14 @@ def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType: fallback = t.fallback else: fallback = s.fallback + instance_type = None + if t.instance_type is not None and s.instance_type is not None: + instance_type = join_types(t.instance_type, s.instance_type) return t.copy_modified( arg_types=arg_types, arg_names=combine_arg_names(t, s), ret_type=join_types(t.ret_type, s.ret_type), + instance_type=instance_type, fallback=fallback, name=None, ) diff --git a/mypy/meet.py b/mypy/meet.py index cb8ad75f6013d..18b2732c55932 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -973,7 +973,7 @@ def visit_callable_type(self, t: CallableType) -> ProperType: return result elif isinstance(self.s, TypeType) and t.is_type_obj() and not t.is_generic(): # In this case we are able to potentially produce a better meet. - res = meet_types(self.s.item, t.ret_type) + res = meet_types(self.s.item, t.get_instance_type()) if not isinstance(res, (NoneType, UninhabitedType)): return TypeType.make_normalized(res) return self.default(self.s) @@ -1182,9 +1182,16 @@ def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType: fallback = t.fallback else: fallback = s.fallback + if t.instance_type is None: + instance_type = s.instance_type + elif s.instance_type is None: + instance_type = t.instance_type + else: + instance_type = meet_types(t.instance_type, s.instance_type) return t.copy_modified( arg_types=arg_types, ret_type=meet_types(t.ret_type, s.ret_type), + instance_type=instance_type, fallback=fallback, name=None, ) diff --git a/mypy/messages.py b/mypy/messages.py index 3de66c7c6082c..3be2c1671ce58 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -2231,13 +2231,11 @@ def report_protocol_problems( subtype = subtype.item elif isinstance(subtype, CallableType): if subtype.is_type_obj(): - ret_type = get_proper_type(subtype.ret_type) - if isinstance(ret_type, TupleType): - ret_type = ret_type.partial_fallback - if not isinstance(ret_type, Instance): + instance_type = subtype.get_instance_type(force_fallback=True) + if not isinstance(instance_type, Instance): return class_obj = True - subtype = ret_type + subtype = instance_type else: subtype = subtype.fallback skip = ["__call__"] @@ -2827,9 +2825,7 @@ def format_literal_value(typ: LiteralType) -> str: elif isinstance(typ, FunctionLike): func = typ if func.is_type_obj(): - # The type of a type object type can be derived from the - # return type (this always works). - return format(TypeType.make_normalized(func.items[0].ret_type)) + return format(TypeType.make_normalized(func.items[0].get_instance_type())) elif isinstance(func, CallableType): if func.type_guard is not None: return_type = f"TypeGuard[{format(func.type_guard)}]" diff --git a/mypy/nodes.py b/mypy/nodes.py index 3dafffa5570dd..5147bef7d1646 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -49,6 +49,7 @@ WriteBuffer, read_bool, read_bytes, + read_flags, read_int, read_int_list, read_int_opt, @@ -61,6 +62,7 @@ read_tag, write_bool, write_bytes, + write_flags, write_int, write_int_list, write_int_opt, @@ -5222,20 +5224,6 @@ def set_flags(node: Node, flags: list[str]) -> None: setattr(node, name, True) -def write_flags(data: WriteBuffer, flags: list[bool]) -> None: - assert len(flags) <= 26, "This many flags not supported yet" - packed = 0 - for i, flag in enumerate(flags): - if flag: - packed |= 1 << i - write_int(data, packed) - - -def read_flags(data: ReadBuffer, num_flags: int) -> list[bool]: - packed = read_int(data) - return [(packed & (1 << i)) != 0 for i in range(num_flags)] - - def get_member_expr_fullname(expr: MemberExpr) -> str | None: """Return the qualified name representation of a member expression. diff --git a/mypy/plugins/singledispatch.py b/mypy/plugins/singledispatch.py index a513b91ff309b..9a5576c17e82c 100644 --- a/mypy/plugins/singledispatch.py +++ b/mypy/plugins/singledispatch.py @@ -126,7 +126,7 @@ def singledispatch_register_callback(ctx: MethodContext) -> Type: # is_subtype doesn't work when the right type is Overloaded, so we need the # actual type - register_type = first_arg_type.items[0].ret_type + register_type = first_arg_type.items[0].get_instance_type() type_args = RegisterCallableInfo(register_type, ctx.type) register_callable = make_fake_register_class_instance(ctx.api, type_args) return register_callable diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 9bbc3077ec512..ecff546049f92 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -470,6 +470,9 @@ def visit_callable_type(self, typ: CallableType) -> SnapshotItem: typ.is_ellipsis_args, snapshot_types(typ.variables), typ.is_bound, + snapshot_optional_type(typ.type_guard), + snapshot_optional_type(typ.type_is), + snapshot_optional_type(typ.instance_type), ) def normalize_callable_variables(self, typ: CallableType) -> CallableType: diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index aaf388b6665d6..075bf7cb540bf 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -452,6 +452,12 @@ def visit_callable_type(self, typ: CallableType) -> None: # Fallback can be None for callable types that haven't been semantically analyzed. if typ.fallback is not None: typ.fallback.accept(self) + if typ.type_guard is not None: + typ.type_guard.accept(self) + if typ.type_is is not None: + typ.type_is.accept(self) + if typ.instance_type is not None: + typ.instance_type.accept(self) for tv in typ.variables: if isinstance(tv, TypeVarType): tv.upper_bound.accept(self) diff --git a/mypy/server/deps.py b/mypy/server/deps.py index ba622329665ea..b2c91d8db4888 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -1003,6 +1003,12 @@ def visit_callable_type(self, typ: CallableType) -> list[str]: for arg in typ.arg_types: triggers.extend(self.get_type_triggers(arg)) triggers.extend(self.get_type_triggers(typ.ret_type)) + if typ.type_guard is not None: + triggers.extend(self.get_type_triggers(typ.type_guard)) + if typ.type_is is not None: + triggers.extend(self.get_type_triggers(typ.type_is)) + if typ.instance_type is not None: + triggers.extend(self.get_type_triggers(typ.instance_type)) # fallback is a metaclass type for class objects, and is # processed separately. return triggers diff --git a/mypy/subtypes.py b/mypy/subtypes.py index b8e8d5e3b79df..102052b65c7bd 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -749,17 +749,15 @@ def visit_callable_type(self, left: CallableType) -> bool: if is_protocol_implementation(left.fallback, right, skip=["__call__"]): return True if right.type.is_protocol and left.is_type_obj(): - ret_type = get_proper_type(left.ret_type) - if isinstance(ret_type, TupleType): - ret_type = mypy.typeops.tuple_fallback(ret_type) - if isinstance(ret_type, Instance) and is_protocol_implementation( - ret_type, right, proper_subtype=self.proper_subtype, class_obj=True + instance_type = left.get_instance_type(force_fallback=True) + if isinstance(instance_type, Instance) and is_protocol_implementation( + instance_type, right, proper_subtype=self.proper_subtype, class_obj=True ): return True return self._is_subtype(left.fallback, right) elif isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return left.is_type_obj() and self._is_subtype(left.ret_type, right.item) + return left.is_type_obj() and self._is_subtype(left.get_instance_type(), right.item) else: return False diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 1b38481ba0004..c408f505d61f3 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -34,6 +34,7 @@ ParamSpecType, PartialType, PlaceholderType, + ProperType, RawExpressionType, TupleType, Type, @@ -254,10 +255,15 @@ def visit_unpack_type(self, t: UnpackType, /) -> Type: return UnpackType(t.type.accept(self)) def visit_callable_type(self, t: CallableType, /) -> Type: + instance_type = None + if t.instance_type is not None: + instance_type = t.instance_type.accept(self) + assert isinstance(instance_type, ProperType) return t.copy_modified( arg_types=self.translate_type_list(t.arg_types), ret_type=t.ret_type.accept(self), variables=self.translate_variables(t.variables), + instance_type=instance_type, ) def visit_tuple_type(self, t: TupleType, /) -> Type: @@ -415,7 +421,11 @@ def visit_instance(self, t: Instance, /) -> T: def visit_callable_type(self, t: CallableType, /) -> T: # FIX generics - return self.query_types(t.arg_types + [t.ret_type]) + types = t.arg_types + [t.ret_type] + # Avoid double-counting when using queries in reports. + if t.instance_type is not None and t.instance_type != t.ret_type: + types.append(t.instance_type) + return self.query_types(types) def visit_tuple_type(self, t: TupleType, /) -> T: return self.query_types([t.partial_fallback] + t.items) @@ -551,12 +561,11 @@ def visit_instance(self, t: Instance, /) -> bool: def visit_callable_type(self, t: CallableType, /) -> bool: # FIX generics # Avoid allocating any objects here as an optimization. - args = self.query_types(t.arg_types) - ret = t.ret_type.accept(self) + inst = t.instance_type.accept(self) if t.instance_type is not None else False if self.strategy == ANY_STRATEGY: - return args or ret + return self.query_types(t.arg_types) or t.ret_type.accept(self) or inst else: - return args and ret + return self.query_types(t.arg_types) and t.ret_type.accept(self) and inst def visit_tuple_type(self, t: TupleType, /) -> bool: return self.query_types([t.partial_fallback] + t.items) diff --git a/mypy/typeops.py b/mypy/typeops.py index e13d6dd0b1732..1312fbd93b251 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -217,7 +217,9 @@ def type_object_type( is_bound=True, fallback=instance_cache.function_type, ) - result: FunctionLike = class_callable(sig, info, fallback, None, is_new=False) + result: FunctionLike = class_callable( + sig, info, None, fallback, None, is_new=False + ) if allow_cache and state.strict_optional: info.type_object_type = result return result @@ -299,19 +301,24 @@ def type_object_type_from_function( special_sig = "dict" if isinstance(signature, CallableType): - return class_callable(signature, info, fallback, special_sig, is_new, orig_self_types[0]) + return class_callable( + signature, info, def_info, fallback, special_sig, is_new, orig_self_types[0] + ) else: # Overloaded __init__/__new__. assert isinstance(signature, Overloaded) items: list[CallableType] = [] for item, orig_self in zip(signature.items, orig_self_types): - items.append(class_callable(item, info, fallback, special_sig, is_new, orig_self)) + items.append( + class_callable(item, info, def_info, fallback, special_sig, is_new, orig_self) + ) return Overloaded(items) def class_callable( init_type: CallableType, info: TypeInfo, + def_info: TypeInfo | None, type_type: Instance, special_sig: str | None, is_new: bool, @@ -329,28 +336,42 @@ def class_callable( default_ret_type = fill_typevars(info) explicit_type = init_ret_type if is_new else orig_self_type if ( + is_new + and explicit_type is not None + # We used to only use the explicit return type of __new__() when it was a subtype + # of the current class. As a result, we may now have a situation like this: + # class C: + # def __new__(cls) -> C: ... + # class D(C): ... + # So we need to ignore the explicit annotation when creating constructor type for D. + and ( + def_info is info + and not isinstance(explicit_type, AnyType) + or not is_subtype(default_ret_type, explicit_type, ignore_type_params=True) + ) + ): + ret_type = explicit_type + elif ( isinstance(explicit_type, (Instance, TupleType, UninhabitedType, LiteralType)) # We have to skip protocols, because it can be a subtype of a return type # by accident. Like `Hashable` is a subtype of `object`. See #11799 and isinstance(default_ret_type, Instance) and not default_ret_type.type.is_protocol - # Only use the declared return type from __new__ or declared self in __init__ - # if it is actually returning a subtype of what we would return otherwise. + # Use the declared self in __init__ if it is a subtype of what we would use otherwise. and is_subtype(explicit_type, default_ret_type, ignore_type_params=True) ): - ret_type: Type = explicit_type + ret_type = explicit_type else: ret_type = default_ret_type - callable_type = init_type.copy_modified( + return init_type.copy_modified( ret_type=ret_type, fallback=type_type, - name=None, + name=info.name, variables=variables, special_sig=special_sig, + instance_type=default_ret_type, ) - c = callable_type.with_name(info.name) - return c def map_type_from_supertype(typ: Type, sub_info: TypeInfo, super_info: TypeInfo) -> Type: @@ -474,7 +495,7 @@ class B(A): pass # Solve for these type arguments using the actual class or instance type. typeargs = infer_type_arguments( - self_vars, self_param_type, original_type, is_supertype=True + self_vars, self_param_type, original_type, is_supertype=True, erase_types=False ) if ( is_classmethod @@ -483,7 +504,11 @@ class B(A): pass ): # In case we call a classmethod through an instance x, fallback to type(x). typeargs = infer_type_arguments( - self_vars, self_param_type, TypeType(original_type), is_supertype=True + self_vars, + self_param_type, + TypeType(original_type), + is_supertype=True, + erase_types=False, ) # Update the method signature with the solutions found. diff --git a/mypy/types.py b/mypy/types.py index 40c3839e2efca..725320f8f38ec 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -37,6 +37,7 @@ Tag, WriteBuffer, read_bool, + read_flags, read_int, read_int_list, read_literal, @@ -46,6 +47,7 @@ read_str_opt_list, read_tag, write_bool, + write_flags, write_int, write_int_list, write_literal, @@ -2139,7 +2141,6 @@ class CallableType(FunctionLike): "arg_types", # Types of function arguments "arg_kinds", # ARG_ constants "arg_names", # Argument names; None if not a keyword argument - "min_args", # Minimum number of arguments; derived from arg_kinds "ret_type", # Return value type "name", # Name (may be None; for error messages and plugins) "definition", # For error messages. May be None. @@ -2159,6 +2160,8 @@ class CallableType(FunctionLike): # (this is used for error messages) "imprecise_arg_kinds", "unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable? + "instance_type", # Real underlying type of a type object. This is different from + # ret_type in case we have e.g. a custom __new__() return annotation. ) def __init__( @@ -2184,6 +2187,7 @@ def __init__( from_concatenate: bool = False, imprecise_arg_kinds: bool = False, unpack_kwargs: bool = False, + instance_type: ProperType | None = None, ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -2195,7 +2199,6 @@ def __init__( # See testParamSpecJoin, that relies on passing e.g `P.args` as plain argument. self.arg_kinds = arg_kinds self.arg_names = list(arg_names) - self.min_args = arg_kinds.count(ARG_POS) self.ret_type = ret_type self.fallback = fallback assert not name or " CT: modified = CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, @@ -2272,6 +2277,7 @@ def copy_modified( else self.imprecise_arg_kinds ), unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs, + instance_type=instance_type if instance_type is not _dummy else self.instance_type, ) # Optimization: Only NewTypes are supported as subtypes since # the class is effectively final, so we can use a cast safely. @@ -2291,6 +2297,10 @@ def kw_arg(self) -> FormalArgument | None: return FormalArgument(None, position, type, False) return None + @property + def min_args(self) -> int: + return self.arg_kinds.count(ARG_POS) + @property def is_var_arg(self) -> bool: """Does this callable have a *args argument?""" @@ -2306,9 +2316,23 @@ def is_type_obj(self) -> bool: get_proper_type(self.ret_type), UninhabitedType ) - def type_object(self) -> mypy.nodes.TypeInfo: + def get_instance_type(self, *, force_fallback: bool = False) -> ProperType: + """Get underlying type of a type object. + + By default, this will return a precise self-type, essentially whatever is + returned by fill_typevars(). Most notably this is a TupleType for named tuples. + If an Instance fallback is required, use force_fallback=True. + """ assert self.is_type_obj() - ret = get_proper_type(self.ret_type) + if self.instance_type is not None: + ret = self.instance_type + else: + # Fall back to historic behavior in case instance_type is not set. This + # will avoid crashes on type objects generated by plugins, and on (unknown) + # corner cases where is_type_obj() may "accidentally" return True. + ret = get_proper_type(self.ret_type) + if not force_fallback: + return ret if isinstance(ret, TypeVarType): ret = get_proper_type(ret.upper_bound) if isinstance(ret, TupleType): @@ -2317,15 +2341,19 @@ def type_object(self) -> mypy.nodes.TypeInfo: ret = ret.fallback if isinstance(ret, LiteralType): ret = ret.fallback - assert isinstance(ret, Instance) - return ret.type + return ret + + def type_object(self) -> mypy.nodes.TypeInfo: + instance_type = self.get_instance_type(force_fallback=True) + assert isinstance(instance_type, Instance) + return instance_type.type def accept(self, visitor: TypeVisitor[T]) -> T: return visitor.visit_callable_type(self) def with_name(self, name: str) -> CallableType: """Return a copy of this type with the specified name.""" - return self.copy_modified(ret_type=self.ret_type, name=name) + return self.copy_modified(name=name) def get_name(self) -> str | None: return self.name @@ -2583,10 +2611,13 @@ def serialize(self) -> JsonDict: "implicit": self.implicit, "is_bound": self.is_bound, "type_guard": self.type_guard.serialize() if self.type_guard is not None else None, - "type_is": (self.type_is.serialize() if self.type_is is not None else None), + "type_is": self.type_is.serialize() if self.type_is is not None else None, "from_concatenate": self.from_concatenate, "imprecise_arg_kinds": self.imprecise_arg_kinds, "unpack_kwargs": self.unpack_kwargs, + "instance_type": ( + self.instance_type.serialize() if self.instance_type is not None else None + ), } @classmethod @@ -2607,35 +2638,56 @@ def deserialize(cls, data: JsonDict) -> CallableType: type_guard=( deserialize_type(data["type_guard"]) if data["type_guard"] is not None else None ), - type_is=(deserialize_type(data["type_is"]) if data["type_is"] is not None else None), + type_is=deserialize_type(data["type_is"]) if data["type_is"] is not None else None, from_concatenate=data["from_concatenate"], imprecise_arg_kinds=data["imprecise_arg_kinds"], unpack_kwargs=data["unpack_kwargs"], + instance_type=( + cast(ProperType, deserialize_type(data["instance_type"])) + if data["instance_type"] is not None + else None + ), ) def write(self, data: WriteBuffer) -> None: write_tag(data, CALLABLE_TYPE) self.fallback.write(data) + write_type_opt(data, self.instance_type) + write_flags( + data, + [ + self.is_ellipsis_args, + self.implicit, + self.is_bound, + self.from_concatenate, + self.imprecise_arg_kinds, + self.unpack_kwargs, + ], + ) write_type_list(data, self.arg_types) write_int_list(data, [int(x.value) for x in self.arg_kinds]) write_str_opt_list(data, self.arg_names) self.ret_type.write(data) write_str_opt(data, self.name) write_type_list(data, self.variables) - write_bool(data, self.is_ellipsis_args) - write_bool(data, self.implicit) - write_bool(data, self.is_bound) write_type_opt(data, self.type_guard) write_type_opt(data, self.type_is) - write_bool(data, self.from_concatenate) - write_bool(data, self.imprecise_arg_kinds) - write_bool(data, self.unpack_kwargs) write_tag(data, END_TAG) @classmethod def read(cls, data: ReadBuffer) -> CallableType: assert read_tag(data) == INSTANCE fallback = Instance.read(data) + instance_type = read_type_opt(data) + assert instance_type is None or isinstance(instance_type, ProperType) + ( + is_ellipsis_args, + implicit, + is_bound, + from_concatenate, + imprecise_arg_kinds, + unpack_kwargs, + ) = read_flags(data, num_flags=6) ret = CallableType( read_type_list(data), [ARG_KINDS[ak] for ak in read_int_list(data)], @@ -2644,14 +2696,15 @@ def read(cls, data: ReadBuffer) -> CallableType: fallback, name=read_str_opt(data), variables=read_type_var_likes(data), - is_ellipsis_args=read_bool(data), - implicit=read_bool(data), - is_bound=read_bool(data), + is_ellipsis_args=is_ellipsis_args, + implicit=implicit, + is_bound=is_bound, type_guard=read_type_opt(data), type_is=read_type_opt(data), - from_concatenate=read_bool(data), - imprecise_arg_kinds=read_bool(data), - unpack_kwargs=read_bool(data), + from_concatenate=from_concatenate, + imprecise_arg_kinds=imprecise_arg_kinds, + unpack_kwargs=unpack_kwargs, + instance_type=instance_type, ) assert read_tag(data) == END_TAG return ret diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index abd0f6bf3bdfe..2a7f41bd97f21 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -93,6 +93,9 @@ def visit_callable_type(self, t: CallableType, /) -> None: if t.type_is is not None: t.type_is.accept(self) + if t.instance_type is not None: + t.instance_type.accept(self) + def visit_tuple_type(self, t: TupleType, /) -> None: self.traverse_type_list(t.items) t.partial_fallback.accept(self) diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 5a66eff2bd3b7..9c64a4ea3a17c 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -7338,7 +7338,7 @@ class A: def __new__(cls) -> int: # E: Incompatible return type for "__new__" (returns "int", but must return a subtype of "A") pass -reveal_type(A()) # N: Revealed type is "__main__.A" +reveal_type(A()) # N: Revealed type is "builtins.int" [case testNewReturnType4] from typing import TypeVar, Type @@ -7448,6 +7448,109 @@ class MyMetaClass(type): class MyClass(metaclass=MyMetaClass): pass +[case testNewReturnType13] +from typing import Protocol + +class Foo(Protocol): + def foo(self) -> str: ... + +class A: + def __new__(cls) -> Foo: ... # E: Incompatible return type for "__new__" (returns "Foo", but must return a subtype of "A") + +reveal_type(A()) # N: Revealed type is "__main__.Foo" +reveal_type(A().foo()) # N: Revealed type is "builtins.str" + +[case testNewReturnType14] +from __future__ import annotations + +class A: + def __new__(cls) -> int: raise # E: Incompatible return type for "__new__" (returns "int", but must return a subtype of "A") + +class B(A): + @classmethod + def foo(cls) -> int: raise + +reveal_type(B.foo()) # N: Revealed type is "builtins.int" +[builtins fixtures/classmethod.pyi] + +[case testNewReturnType15] +from typing import Generic, Type, TypeVar + +T = TypeVar("T") + +class A(Generic[T]): + def __new__(cls) -> B[int]: ... + @classmethod + def foo(cls: Type[A[T]]) -> T: ... + +class B(A[T]): ... + +# The Never without error is not ideal, but matches the behavior without custom __new__(). +reveal_type(B.foo()) # N: Revealed type is "Never" +reveal_type(B[str].foo()) # N: Revealed type is "builtins.str" + +class C(A[str]): ... + +reveal_type(C.foo()) # N: Revealed type is "builtins.str" +[builtins fixtures/classmethod.pyi] + +[case testNewReturnType16] +from typing import Generic, TypeVar + +T = TypeVar("T") +class A(Generic[T]): + def __new__(cls, *args, **kwargs) -> T: # E: "__new__" must return a class instance (got "T") + ... + +class Model: + pass + +reveal_type(A[Model]()) # N: Revealed type is "__main__.Model" + +class B(A[Model]): + pass + +reveal_type(B()) # N: Revealed type is "__main__.Model" +[builtins fixtures/dict.pyi] + +[case testNewReturnType17] +class C: + def __new__(self) -> D: + return D() + +class D(C): + x: int + +C.x # E: "type[C]" has no attribute "x" + +[case testNewReturnType18] +class A: + def __new__(cls) -> A: + return A() + +class B(A): + def __new__(cls) -> A: # E: Incompatible return type for "__new__" (returns "A", but must return a subtype of "B") + return super().__new__(cls) + +class C(B): ... + +# Always respect explicit return type after giving an error. +reveal_type(B()) # N: Revealed type is "__main__.A" + +# Ignore "implicit" return type to preserve backwards compatibility. +reveal_type(C()) # N: Revealed type is "__main__.C" + +[case testNewReturnType19] +from typing import TypeVar + +T = TypeVar("T") + +def f(tp: type[T]) -> T: ... + +class C: + def __new__(cls) -> int: ... # type: ignore[misc] + +reveal_type(f(C)) # N: Revealed type is "__main__.C" [case testMetaclassPlaceholderNode] from sympy.assumptions import ManagedProperties diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index a3a5b02d54f89..2db33a05b62dd 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3688,3 +3688,12 @@ reveal_type(ok3) # N: Revealed type is "tuple[()]" bad1: list[()] = [] # E: "list" expects 1 type argument, but none given \ # E: Missing type arguments for generic type "list" [builtins fixtures/tuple.pyi] + +[case testGenericClassAsArgumentToType] +from typing import TypeVar, Generic + +T = TypeVar("T") +def test(tp: type[T]) -> T: ... + +class C(Generic[T]): ... +reveal_type(test(C)) # N: Revealed type is "__main__.C[Any]"