From 1d95c83ef27332ce7463c8fcc91d086e7eeaeebb Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 11 Aug 2025 14:30:04 +0200 Subject: [PATCH 01/12] special case TypedDict.get --- mypy/checkmember.py | 62 +++++++++++++++++++++++++++++ test-data/unit/check-typeddict.test | 22 ++++++---- test-data/unit/pythoneval.test | 11 +++-- 3 files changed, 84 insertions(+), 11 deletions(-) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 2c41f2e273cc..e384e8c5e5b3 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -68,6 +68,7 @@ TypedDictType, TypeOfAny, TypeType, + TypeVarId, TypeVarLikeType, TypeVarTupleType, TypeVarType, @@ -1400,6 +1401,67 @@ def analyze_typeddict_access( fallback=mx.chk.named_type("builtins.function"), name=name, ) + elif name == "get": + # synthesize TypedDict.get() overloads + t = TypeVarType( + "T", + "T", + id=TypeVarId(-1), + values=[], + upper_bound=mx.chk.named_type("builtins.object"), + default=AnyType(TypeOfAny.from_omitted_generics), + ) + str_type = mx.chk.named_type("builtins.str") + fn_type = mx.chk.named_type("builtins.function") + object_type = mx.chk.named_type("builtins.object") + + overloads: list[CallableType] = [] + # add two overloads per TypedDictType spec + for key, val in typ.items.items(): + # first overload: def(Literal[key]) -> val + no_default = CallableType( + arg_types=[LiteralType(key, fallback=str_type)], + arg_kinds=[ARG_POS], + arg_names=[None], + ret_type=val, + fallback=fn_type, + name=name, + ) + # second Overload: def [T] (Literal[key], default: T | Val, /) -> T | Val + with_default = CallableType( + variables=[t], + arg_types=[LiteralType(key, fallback=str_type), UnionType.make_union([val, t])], + arg_kinds=[ARG_POS, ARG_POS], + arg_names=[None, None], + ret_type=UnionType.make_union([val, t]), + fallback=fn_type, + name=name, + ) + overloads.append(no_default) + overloads.append(with_default) + + # finally, add fallback overloads when a key is used that is not in the TypedDict + # def (str) -> object + fallback_no_default = CallableType( + arg_types=[str_type], + arg_kinds=[ARG_POS], + arg_names=[None], + ret_type=object_type, + fallback=fn_type, + name=name, + ) + # def (str, object) -> object + fallback_with_default = CallableType( + arg_types=[str_type, object_type], + arg_kinds=[ARG_POS, ARG_POS], + arg_names=[None, None], + ret_type=object_type, + fallback=fn_type, + name=name, + ) + overloads.append(fallback_no_default) + overloads.append(fallback_with_default) + return Overloaded(overloads) return _analyze_member_access(name, typ.fallback, mx, override_info) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 34cae74d795b..86b917cb5249 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -1016,7 +1016,7 @@ class A: pass D = TypedDict('D', {'x': List[int], 'y': int}) d: D reveal_type(d.get('x', [])) # N: Revealed type is "builtins.list[builtins.int]" -d.get('x', ['x']) # E: List item 0 has incompatible type "str"; expected "int" +reveal_type(d.get('x', ['x'])) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]" a = [''] reveal_type(d.get('x', a)) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]" [builtins fixtures/dict.pyi] @@ -1026,14 +1026,22 @@ reveal_type(d.get('x', a)) # N: Revealed type is "Union[builtins.list[builtins.i from typing import TypedDict D = TypedDict('D', {'x': int, 'y': str}) d: D -d.get() # E: All overload variants of "get" of "Mapping" require at least one argument \ +d.get() # E: All overload variants of "get" require at least one argument \ # N: Possible overload variants: \ - # N: def get(self, k: str) -> object \ - # N: def [V] get(self, k: str, default: object) -> object -d.get('x', 1, 2) # E: No overload variant of "get" of "Mapping" matches argument types "str", "int", "int" \ + # N: def get(Literal['x'], /) -> int \ + # N: def [T] get(Literal['x'], Union[int, T], /) -> Union[int, T] \ + # N: def get(Literal['y'], /) -> str \ + # N: def [T] get(Literal['y'], Union[str, T], /) -> Union[str, T] \ + # N: def get(str, /) -> object \ + # N: def get(str, object, /) -> object +d.get('x', 1, 2) # E: No overload variant of "get" matches argument types "str", "int", "int" \ # N: Possible overload variants: \ - # N: def get(self, k: str) -> object \ - # N: def [V] get(self, k: str, default: Union[int, V]) -> object + # N: def get(Literal['x'], /) -> int \ + # N: def [T] get(Literal['x'], Union[int, T], /) -> Union[int, T] \ + # N: def get(Literal['y'], /) -> str \ + # N: def [T] get(Literal['y'], Union[int, T], /) -> Union[str, T] \ + # N: def get(str, /) -> object \ + # N: def get(str, object, /) -> object x = d.get('z') reveal_type(x) # N: Revealed type is "builtins.object" s = '' diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 9b5d8a1ac54c..c1826e7c5d67 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1046,11 +1046,14 @@ reveal_type(d.get(s)) _testTypedDictGet.py:6: note: Revealed type is "Union[builtins.int, None]" _testTypedDictGet.py:7: note: Revealed type is "Union[builtins.str, None]" _testTypedDictGet.py:8: note: Revealed type is "builtins.object" -_testTypedDictGet.py:9: error: All overload variants of "get" of "Mapping" require at least one argument +_testTypedDictGet.py:9: error: All overload variants of "get" require at least one argument _testTypedDictGet.py:9: note: Possible overload variants: -_testTypedDictGet.py:9: note: def get(self, str, /) -> object -_testTypedDictGet.py:9: note: def get(self, str, /, default: object) -> object -_testTypedDictGet.py:9: note: def [_T] get(self, str, /, default: _T) -> object +_testTypedDictGet.py:9: note: def get(Literal['x'], /) -> int +_testTypedDictGet.py:9: note: def [T] get(Literal['x'], Union[int, T], /) -> Union[int, T] +_testTypedDictGet.py:9: note: def get(Literal['y'], /) -> str +_testTypedDictGet.py:9: note: def [T] get(Literal['y'], Union[str, T], /) -> Union[str, T] +_testTypedDictGet.py:9: note: def get(str, /) -> object +_testTypedDictGet.py:9: note: def get(str, object, /) -> object _testTypedDictGet.py:11: note: Revealed type is "builtins.object" [case testTypedDictMappingMethods] From 76194eb591fadbd9bb0d1b4e4e2e8d492b2a226c Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Thu, 7 Aug 2025 12:53:36 +0200 Subject: [PATCH 02/12] initial draft --- mypy/meet.py | 38 +++++++++++++++++++++++++++-- mypy/subtypes.py | 13 ++++++++-- test-data/unit/check-inference.test | 4 +-- test-data/unit/check-python310.test | 4 +-- test-data/unit/check-python38.test | 16 ++++++------ 5 files changed, 59 insertions(+), 16 deletions(-) diff --git a/mypy/meet.py b/mypy/meet.py index 349c15e668c3..3f59bc83a54c 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -94,6 +94,36 @@ def meet_types(s: Type, t: Type) -> ProperType: return s return t + # special casing for dealing with last known values + if is_proper_subtype(s, t, ignore_promotions=True) and is_proper_subtype( + t, s, ignore_promotions=True + ): + lkv: LiteralType | None + if s.last_known_value is None and t.last_known_value is None: + # Both types have no last known value, so we return the original type. + lkv = None + elif s.last_known_value is None and t.last_known_value is not None: + lkv = t.last_known_value + elif s.last_known_value is not None and t.last_known_value is None: + lkv = s.last_known_value + elif s.last_known_value is not None and t.last_known_value is not None: + lkv_meet = meet_types(s.last_known_value, t.last_known_value) + if isinstance(lkv_meet, UninhabitedType): + lkv = None + elif isinstance(lkv_meet, LiteralType): + lkv = lkv_meet + else: + msg = ( + f"Unexpected meet result for last known values: " + f"{s.last_known_value=} and {t.last_known_value=} " + f"resulted in {lkv_meet=}" + ) + raise ValueError(msg) + else: + assert False + assert lkv is None or isinstance(lkv, LiteralType) + return t.copy_modified(last_known_value=lkv) + if not isinstance(s, UnboundType) and not isinstance(t, UnboundType): if is_proper_subtype(s, t, ignore_promotions=True): return s @@ -1088,8 +1118,12 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType: def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType) and self.s == t: return t - elif isinstance(self.s, Instance) and is_subtype(t.fallback, self.s): - return t + elif isinstance(self.s, Instance): + if is_subtype(t.fallback, self.s): + return t + if self.s.last_known_value is not None: + return meet_types(self.s.last_known_value, t) + return self.default(self.s) else: return self.default(self.s) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f3..d5edcccbe1d9 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -629,8 +629,17 @@ def visit_instance(self, left: Instance) -> bool: return True if isinstance(item, Instance): return is_named_instance(item, "builtins.object") - if isinstance(right, LiteralType) and left.last_known_value is not None: - return self._is_subtype(left.last_known_value, right) + # if isinstance(right, LiteralType) and left.last_known_value is not None: + # return self._is_subtype(left.last_known_value, right) + if isinstance(right, LiteralType): + if self.proper_subtype: + # Instance types like Literal["sum"]? is *assignable* to Literal["sum"], + # but is not a proper subtype of it. (Literal["sum"]? is a gradual type, + # that is a proper subtype of str, and is assignable to Literal["sum"], + # but not a proper subtype of it.) + return False + if left.last_known_value is not None: + return self._is_subtype(left.last_known_value, right) if isinstance(right, FunctionLike): # Special case: Instance can be a subtype of Callable / Overloaded. call = find_member("__call__", left, left, is_operator=True) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 53efcc0d22e3..c101f9dc9536 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -4076,7 +4076,7 @@ def check_and(maybe: bool) -> None: bar = None if maybe and (foo := [1])[(bar := 0)]: reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" - reveal_type(bar) # N: Revealed type is "builtins.int" + reveal_type(bar) # N: Revealed type is "Literal[0]?" else: reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]" reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]" @@ -4102,7 +4102,7 @@ def check_or(maybe: bool) -> None: reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]" else: reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" - reveal_type(bar) # N: Revealed type is "builtins.int" + reveal_type(bar) # N: Revealed type is "Literal[0]?" def check_or_nested(maybe: bool) -> None: foo = None diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index f264167cb067..b14527963bd6 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1354,7 +1354,7 @@ m: str match m: case a if a := "test": - reveal_type(a) # N: Revealed type is "builtins.str" + reveal_type(a) # N: Revealed type is "Literal['test']?" [case testMatchNarrowingPatternGuard] m: object @@ -2686,7 +2686,7 @@ match m[k]: match 0: case 0 as i: - reveal_type(i) # N: Revealed type is "Literal[0]?" + reveal_type(i) # N: Revealed type is "Literal[0]" case int(i): i # E: Statement is unreachable case other: diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index dd3f793fd02b..8b7e8441c6b7 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -214,10 +214,10 @@ i(arg=0) # E: Unexpected keyword argument "arg" from typing import Final, NamedTuple, Optional, List if a := 2: - reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(a) # N: Revealed type is "Literal[2]?" while b := "x": - reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(b) # N: Revealed type is "Literal['x']?" l = [y2 := 1, y2 + 2, y2 + 3] reveal_type(y2) # N: Revealed type is "builtins.int" @@ -242,10 +242,10 @@ reveal_type(new_v) # N: Revealed type is "builtins.int" def f(x: int = (c := 4)) -> int: if a := 2: - reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(a) # N: Revealed type is "Literal[2]?" while b := "x": - reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(b) # N: Revealed type is "Literal['x']?" x = (y := 1) + (z := 2) reveal_type(x) # N: Revealed type is "builtins.int" @@ -284,7 +284,7 @@ def f(x: int = (c := 4)) -> int: f(x=(y7 := 3)) reveal_type(y7) # N: Revealed type is "builtins.int" - reveal_type((lambda: (y8 := 3) and y8)()) # N: Revealed type is "builtins.int" + reveal_type((lambda: (y8 := 3) and y8)()) # N: Revealed type is "Literal[3]?" y8 # E: Name "y8" is not defined y7 = 1.0 # E: Incompatible types in assignment (expression has type "float", variable has type "int") @@ -325,16 +325,16 @@ def check_binder(x: Optional[int], y: Optional[int], z: Optional[int], a: Option reveal_type(y) # N: Revealed type is "Union[builtins.int, None]" if x and (y := 1): - reveal_type(y) # N: Revealed type is "builtins.int" + reveal_type(y) # N: Revealed type is "Literal[1]?" if (a := 1) and x: - reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(a) # N: Revealed type is "Literal[1]?" if (b := 1) or x: reveal_type(b) # N: Revealed type is "builtins.int" if z := 1: - reveal_type(z) # N: Revealed type is "builtins.int" + reveal_type(z) # N: Revealed type is "Literal[1]?" def check_partial() -> None: x = None From 611ceafc29691845c9771bfa459da7399801c354 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Thu, 7 Aug 2025 16:50:17 +0200 Subject: [PATCH 03/12] fix else branch in match case with literals --- mypy/subtypes.py | 5 +++++ mypy/test/testtypes.py | 6 ++++-- mypy/typeops.py | 25 +++++++++++++++++++++++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index d5edcccbe1d9..c609d03ea74d 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -2136,6 +2136,11 @@ def covers_at_runtime(item: Type, supertype: Type) -> bool: item = get_proper_type(item) supertype = get_proper_type(supertype) + # Use last known value for Instance types, if available. + # This ensures that e.g. Literal["max"]? is covered by Literal["max"]. + if isinstance(item, Instance) and item.last_known_value is not None: + item = item.last_known_value + # Since runtime type checks will ignore type arguments, erase the types. if not (isinstance(supertype, FunctionLike) and supertype.is_type_obj()): supertype = erase_type(supertype) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 0fe41bc28ecd..dececfb80358 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -601,7 +601,7 @@ def test_simplified_union_with_literals(self) -> None: [fx.lit1_inst, fx.lit3_inst], UnionType([fx.lit1_inst, fx.lit3_inst]) ) self.assert_simplified_union([fx.lit1_inst, fx.uninhabited], fx.lit1_inst) - self.assert_simplified_union([fx.lit1, fx.lit1_inst], fx.lit1) + self.assert_simplified_union([fx.lit1, fx.lit1_inst], fx.lit1_inst) self.assert_simplified_union([fx.lit1, fx.lit2_inst], UnionType([fx.lit1, fx.lit2_inst])) self.assert_simplified_union([fx.lit1, fx.lit3_inst], UnionType([fx.lit1, fx.lit3_inst])) @@ -651,7 +651,9 @@ def test_simplified_union_with_mixed_str_literals(self) -> None: [fx.lit_str1, fx.lit_str2, fx.lit_str3_inst], UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]), ) - self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1) + self.assert_simplified_union( + [fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1_inst + ) def assert_simplified_union(self, original: list[Type], union: Type) -> None: assert_equal(make_simplified_union(original), union) diff --git a/mypy/typeops.py b/mypy/typeops.py index 88b3c5da48ce..8dd9b45c4821 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -572,6 +572,8 @@ def make_simplified_union( * [int, Any] -> Union[int, Any] (Any types are not simplified away!) * [Any, Any] -> Any * [int, Union[bytes, str]] -> Union[int, bytes, str] + * [Literal[1]?, Literal[1]] -> Literal[1]? + * Literal["max"]?, Literal["max", "sum"] -> Literal["max"]? | Literal["sum"] Note: This must NOT be used during semantic analysis, since TypeInfos may not be fully initialized. @@ -600,13 +602,32 @@ def make_simplified_union( ): simplified_set = try_contracting_literals_in_union(simplified_set) - result = get_proper_type(UnionType.make_union(simplified_set, line, column)) + # Step 5: Combine Literals and Instances with LKVs, e.g. Literal[1]?, Literal[1] -> Literal[1]? + new_items = [] + for item in simplified_set: + if isinstance(item, LiteralType): + # scan if there is an Instance with a last_known_value that matches + for other in simplified_set: + if ( + isinstance(other, Instance) + and other.last_known_value is not None + and item == other.last_known_value + ): + # do not include item + break + else: + new_items.append(item) + else: + # If the item is not a LiteralType, we can use it directly. + new_items.append(item) + + result = get_proper_type(UnionType.make_union(new_items, line, column)) nitems = len(items) if nitems > 1 and ( nitems > 2 or not (type(items[0]) is NoneType or type(items[1]) is NoneType) ): - # Step 5: At last, we erase any (inconsistent) extra attributes on instances. + # Step 6: At last, we erase any (inconsistent) extra attributes on instances. # Initialize with None instead of an empty set as a micro-optimization. The set # is needed very rarely, so we try to avoid constructing it. From 822dd30d03cf4df048459616508f80a4b556deac Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Thu, 7 Aug 2025 17:16:07 +0200 Subject: [PATCH 04/12] simplify literal elimination --- mypy/test/testtypes.py | 2 ++ mypy/typeops.py | 28 +++++++++------------------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index dececfb80358..2657ed26de64 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -647,6 +647,8 @@ def test_simplified_union_with_str_instance_literals(self) -> None: def test_simplified_union_with_mixed_str_literals(self) -> None: fx = self.fx + self.assert_simplified_union([fx.lit_str1, fx.lit_str1_inst], fx.lit_str1_inst) + self.assert_simplified_union( [fx.lit_str1, fx.lit_str2, fx.lit_str3_inst], UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]), diff --git a/mypy/typeops.py b/mypy/typeops.py index 8dd9b45c4821..58e860dc2951 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -603,25 +603,15 @@ def make_simplified_union( simplified_set = try_contracting_literals_in_union(simplified_set) # Step 5: Combine Literals and Instances with LKVs, e.g. Literal[1]?, Literal[1] -> Literal[1]? - new_items = [] - for item in simplified_set: - if isinstance(item, LiteralType): - # scan if there is an Instance with a last_known_value that matches - for other in simplified_set: - if ( - isinstance(other, Instance) - and other.last_known_value is not None - and item == other.last_known_value - ): - # do not include item - break - else: - new_items.append(item) - else: - # If the item is not a LiteralType, we can use it directly. - new_items.append(item) - - result = get_proper_type(UnionType.make_union(new_items, line, column)) + proper_items: list[ProperType] = list(map(get_proper_type, simplified_set)) + last_known_values: list[LiteralType | None] = [ + p_t.last_known_value if isinstance(p_t, Instance) else None for p_t in proper_items + ] + simplified_set = [ + item for item, p_t in zip(simplified_set, proper_items) if p_t not in last_known_values + ] + + result = get_proper_type(UnionType.make_union(simplified_set, line, column)) nitems = len(items) if nitems > 1 and ( From d8a6f94c15ca4c1570e7afae117673e57c221fb8 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Thu, 7 Aug 2025 17:23:07 +0200 Subject: [PATCH 05/12] use comprehension instead of map --- mypy/typeops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index 58e860dc2951..a51e8745a641 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -603,7 +603,7 @@ def make_simplified_union( simplified_set = try_contracting_literals_in_union(simplified_set) # Step 5: Combine Literals and Instances with LKVs, e.g. Literal[1]?, Literal[1] -> Literal[1]? - proper_items: list[ProperType] = list(map(get_proper_type, simplified_set)) + proper_items: list[ProperType] = [get_proper_type(t) for t in simplified_set] last_known_values: list[LiteralType | None] = [ p_t.last_known_value if isinstance(p_t, Instance) else None for p_t in proper_items ] From 8a5c7dc708a2184c85fdecf0df48c8cd7307e478 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Fri, 8 Aug 2025 11:01:05 +0200 Subject: [PATCH 06/12] simplification --- mypy/meet.py | 54 +++++++++++++++++++++--------------------------- mypy/subtypes.py | 5 +---- 2 files changed, 25 insertions(+), 34 deletions(-) diff --git a/mypy/meet.py b/mypy/meet.py index 3f59bc83a54c..63a5fe54f7dd 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -81,6 +81,30 @@ def meet_types(s: Type, t: Type) -> ProperType: t = get_proper_type(t) if isinstance(s, Instance) and isinstance(t, Instance) and s.type == t.type: + # special casing for dealing with last known values + lkv: LiteralType | None + + if s.last_known_value is None: + lkv = t.last_known_value + elif t.last_known_value is None: + lkv = s.last_known_value + else: + lkv_meet = meet_types(s.last_known_value, t.last_known_value) + if isinstance(lkv_meet, UninhabitedType): + lkv = None + elif isinstance(lkv_meet, LiteralType): + lkv = lkv_meet + else: + msg = ( + f"Unexpected result: " + f"meet of {s.last_known_value=!s} and {t.last_known_value=!s} " + f"resulted in {lkv_meet!s}" + ) + raise ValueError(msg) + + t = t.copy_modified(last_known_value=lkv) + s = s.copy_modified(last_known_value=lkv) + # Code in checker.py should merge any extra_items where possible, so we # should have only compatible extra_items here. We check this before # the below subtype check, so that extra_attrs will not get erased. @@ -94,36 +118,6 @@ def meet_types(s: Type, t: Type) -> ProperType: return s return t - # special casing for dealing with last known values - if is_proper_subtype(s, t, ignore_promotions=True) and is_proper_subtype( - t, s, ignore_promotions=True - ): - lkv: LiteralType | None - if s.last_known_value is None and t.last_known_value is None: - # Both types have no last known value, so we return the original type. - lkv = None - elif s.last_known_value is None and t.last_known_value is not None: - lkv = t.last_known_value - elif s.last_known_value is not None and t.last_known_value is None: - lkv = s.last_known_value - elif s.last_known_value is not None and t.last_known_value is not None: - lkv_meet = meet_types(s.last_known_value, t.last_known_value) - if isinstance(lkv_meet, UninhabitedType): - lkv = None - elif isinstance(lkv_meet, LiteralType): - lkv = lkv_meet - else: - msg = ( - f"Unexpected meet result for last known values: " - f"{s.last_known_value=} and {t.last_known_value=} " - f"resulted in {lkv_meet=}" - ) - raise ValueError(msg) - else: - assert False - assert lkv is None or isinstance(lkv, LiteralType) - return t.copy_modified(last_known_value=lkv) - if not isinstance(s, UnboundType) and not isinstance(t, UnboundType): if is_proper_subtype(s, t, ignore_promotions=True): return s diff --git a/mypy/subtypes.py b/mypy/subtypes.py index c609d03ea74d..ef52a22b46e5 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -629,14 +629,11 @@ def visit_instance(self, left: Instance) -> bool: return True if isinstance(item, Instance): return is_named_instance(item, "builtins.object") - # if isinstance(right, LiteralType) and left.last_known_value is not None: - # return self._is_subtype(left.last_known_value, right) if isinstance(right, LiteralType): if self.proper_subtype: # Instance types like Literal["sum"]? is *assignable* to Literal["sum"], # but is not a proper subtype of it. (Literal["sum"]? is a gradual type, - # that is a proper subtype of str, and is assignable to Literal["sum"], - # but not a proper subtype of it.) + # that is a proper subtype of str, and assignable to Literal["sum"]. return False if left.last_known_value is not None: return self._is_subtype(left.last_known_value, right) From 4a69e5f1fa0940370e4d45ce0be719f9f1c80304 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Fri, 8 Aug 2025 16:22:52 +0200 Subject: [PATCH 07/12] improve testing of basic ops with last_known_value --- mypy/join.py | 29 +++++-- mypy/meet.py | 6 +- mypy/subtypes.py | 6 ++ mypy/test/testsubtypes.py | 126 +++++++++++++++++++++++++++- mypy/test/testtypes.py | 109 +++++++++++++++++++++++- test-data/unit/check-literal.test | 2 +- test-data/unit/check-python310.test | 12 ++- 7 files changed, 275 insertions(+), 15 deletions(-) diff --git a/mypy/join.py b/mypy/join.py index 099df02680f0..aa0970d0ba10 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -69,6 +69,10 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: # Simplest case: join two types with the same base type (but # potentially different arguments). + last_known_value = ( + None if t.last_known_value != s.last_known_value else t.last_known_value + ) + # Combine type arguments. args: list[Type] = [] # N.B: We use zip instead of indexing because the lengths might have @@ -104,10 +108,10 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: new_type = join_types(ta, sa, self) if len(type_var.values) != 0 and new_type not in type_var.values: self.seen_instances.pop() - return object_from_instance(t) + return object_from_instance(t, last_known_value=last_known_value) if not is_subtype(new_type, type_var.upper_bound): self.seen_instances.pop() - return object_from_instance(t) + return object_from_instance(t, last_known_value=last_known_value) # TODO: contravariant case should use meet but pass seen instances as # an argument to keep track of recursive checks. elif type_var.variance in (INVARIANT, CONTRAVARIANT): @@ -117,7 +121,7 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: new_type = ta elif not is_equivalent(ta, sa): self.seen_instances.pop() - return object_from_instance(t) + return object_from_instance(t, last_known_value=last_known_value) else: # If the types are different but equivalent, then an Any is involved # so using a join in the contravariant case is also OK. @@ -141,7 +145,7 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: new_type = join_types(ta, sa, self) assert new_type is not None args.append(new_type) - result: ProperType = Instance(t.type, args) + result: ProperType = Instance(t.type, args, last_known_value=last_known_value) elif t.type.bases and is_proper_subtype( t, s, subtype_context=SubtypeContext(ignore_type_params=True) ): @@ -270,6 +274,10 @@ def visit_unbound_type(self, t: UnboundType) -> ProperType: def visit_union_type(self, t: UnionType) -> ProperType: if is_proper_subtype(self.s, t): return t + elif isinstance(self.s, LiteralType): + # E.g. join("x", "y" | "z") -> "x" | "y" | "z" + # and join(1, "y" | "z") -> object + return mypy.typeops.make_simplified_union(join_types(self.s, x) for x in t.items) else: return mypy.typeops.make_simplified_union([self.s, t]) @@ -621,13 +629,18 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType: def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType): if t == self.s: + # E.g. Literal["x"], Literal["x"] -> Literal["x"] return t - if self.s.fallback.type.is_enum and t.fallback.type.is_enum: + if (self.s.fallback.type == t.fallback.type) or ( + self.s.fallback.type.is_enum and t.fallback.type.is_enum + ): return mypy.typeops.make_simplified_union([self.s, t]) return join_types(self.s.fallback, t.fallback) elif isinstance(self.s, Instance) and self.s.last_known_value == t: + # E.g. Literal["x"], Literal["x"]? -> Literal["x"] return t else: + # E.g. Literal["x"], Literal["y"]? -> str return join_types(self.s, t.fallback) def visit_partial_type(self, t: PartialType) -> ProperType: @@ -848,10 +861,12 @@ def combine_arg_names( return new_names -def object_from_instance(instance: Instance) -> Instance: +def object_from_instance( + instance: Instance, last_known_value: LiteralType | None = None +) -> Instance: """Construct the type 'builtins.object' from an instance type.""" # Use the fact that 'object' is always the last class in the mro. - res = Instance(instance.type.mro[-1], []) + res = Instance(instance.type.mro[-1], [], last_known_value=last_known_value) return res diff --git a/mypy/meet.py b/mypy/meet.py index 63a5fe54f7dd..649fe201e980 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -1113,9 +1113,11 @@ def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType) and self.s == t: return t elif isinstance(self.s, Instance): - if is_subtype(t.fallback, self.s): - return t + # if is_subtype(t.fallback, self.s): + # return t if self.s.last_known_value is not None: + # meet(Literal["max"]?, Literal["max"]) -> Literal["max"] + # meet(Literal["sum"]?, Literal["max"]) -> Never return meet_types(self.s.last_known_value, t) return self.default(self.s) else: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index ef52a22b46e5..0065133402b3 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -971,6 +971,12 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: def visit_literal_type(self, left: LiteralType) -> bool: if isinstance(self.right, LiteralType): return left == self.right + elif ( + isinstance(self.right, Instance) + and self.right.last_known_value is not None + and self.proper_subtype + ): + return self._is_subtype(left, self.right.last_known_value) else: return self._is_subtype(left.fallback, self.right) diff --git a/mypy/test/testsubtypes.py b/mypy/test/testsubtypes.py index b75c22bca7f7..9c191ab3cfd8 100644 --- a/mypy/test/testsubtypes.py +++ b/mypy/test/testsubtypes.py @@ -1,7 +1,7 @@ from __future__ import annotations from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT -from mypy.subtypes import is_subtype +from mypy.subtypes import is_proper_subtype, is_subtype, restrict_subtype_away from mypy.test.helpers import Suite from mypy.test.typefixture import InterfaceTypeFixture, TypeFixture from mypy.types import Instance, TupleType, Type, UninhabitedType, UnpackType @@ -277,6 +277,74 @@ def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None: def test_fallback_not_subtype_of_tuple(self) -> None: self.assert_not_subtype(self.fx.a, TupleType([self.fx.b], fallback=self.fx.a)) + def test_literal(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" ≲ str -> YES + # str ≲ "x" -> NO + # "x"? ≲ str -> YES + # str ≲ "x"? -> YES + self.assert_subtype(str1, str_type) + self.assert_not_subtype(str_type, str1) + self.assert_subtype(str1_inst, str_type) + self.assert_subtype(str_type, str1_inst) + + # other operand is the same literal + # "x" ≲ "x" -> YES + # "x" ≲ "x"? -> YES + # "x"? ≲ "x" -> YES + # "x"? ≲ "x"? -> YES + self.assert_subtype(str1, str1) + self.assert_subtype(str1, str1_inst) + self.assert_subtype(str1_inst, str1) + self.assert_subtype(str1_inst, str1_inst) + + # second operand is a different literal + # "x" ≲ "y" -> NO + # "x" ≲ "y"? -> YES + # "x"? ≲ "y" -> NO + # "x"? ≲ "y"? -> YES + self.assert_not_subtype(str1, str2) + self.assert_subtype(str1, str2_inst) + self.assert_not_subtype(str1_inst, str2) + self.assert_subtype(str1_inst, str2_inst) + + # check proper subtyping + # second operand is the fallback type + # "x" <: str -> YES + # str <: "x" -> NO + # "x"? <: str -> YES + # str <: "x"? -> YES + self.assert_proper_subtype(str1, str_type) + self.assert_not_proper_subtype(str_type, str1) + self.assert_proper_subtype(str1_inst, str_type) + self.assert_proper_subtype(str_type, str1_inst) + + # second operand is the same literal + # "x" <: "x" -> YES + # "x" <: "x"? -> YES + # "x"? <: "x" -> NO + # "x"? <: "x"? -> YES + self.assert_proper_subtype(str1, str1) + self.assert_proper_subtype(str1, str1_inst) + self.assert_not_proper_subtype(str1_inst, str1) + self.assert_proper_subtype(str1_inst, str1_inst) + + # second operand is a different literal + # "x" ≲ "y" -> NO + # "x" ≲ "y"? -> NO + # "x"? ≲ "y" -> NO + # "x"? ≲ "y"? -> YES + self.assert_not_proper_subtype(str1, str2) + self.assert_not_proper_subtype(str1, str2_inst) + self.assert_not_proper_subtype(str1_inst, str2) + self.assert_proper_subtype(str1_inst, str2_inst) + # IDEA: Maybe add these test cases (they are tested pretty well in type # checker tests already): # * more interface subtyping test cases @@ -287,6 +355,12 @@ def test_fallback_not_subtype_of_tuple(self) -> None: # * any type # * generic function types + def assert_proper_subtype(self, s: Type, t: Type) -> None: + assert is_proper_subtype(s, t), f"{s} not proper subtype of {t}" + + def assert_not_proper_subtype(self, s: Type, t: Type) -> None: + assert not is_proper_subtype(s, t), f"{s} not proper subtype of {t}" + def assert_subtype(self, s: Type, t: Type) -> None: assert is_subtype(s, t), f"{s} not subtype of {t}" @@ -304,3 +378,53 @@ def assert_equivalent(self, s: Type, t: Type) -> None: def assert_unrelated(self, s: Type, t: Type) -> None: self.assert_not_subtype(s, t) self.assert_not_subtype(t, s) + + +class RestrictionSuite(Suite): + # Tests for type restrictions "A - B", i.e. ``T <: A and not T <: B``. + + def setUp(self) -> None: + self.fx = TypeFixture() + + def assert_restriction(self, s: Type, t: Type, expected: Type) -> None: + actual = restrict_subtype_away(s, t) + msg = f"restrict_subtype_away({s}, {t}) == {{}} ({{}} expected)" + self.assertEqual(actual, expected, msg=msg.format(actual, expected)) + + def test_literal(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + uninhabited = self.fx.uninhabited + + # other operand is the fallback type + # "x" - str -> Never + # str - "x" -> str + # "x"? - str -> Never + # str - "x"? -> Never + self.assert_restriction(str1, str_type, uninhabited) + self.assert_restriction(str_type, str1, str_type) + self.assert_restriction(str1_inst, str_type, uninhabited) + self.assert_restriction(str_type, str1_inst, uninhabited) + + # other operand is the same literal + # "x" - "x" -> Never + # "x" - "x"? -> Never + # "x"? - "x" -> Never + # "x"? - "x"? -> Never + self.assert_restriction(str1, str1, uninhabited) + self.assert_restriction(str1, str1_inst, uninhabited) + self.assert_restriction(str1_inst, str1, uninhabited) + self.assert_restriction(str1_inst, str1_inst, uninhabited) + + # other operand is a different literal + # "x" - "y" -> "x" + # "x" - "y"? -> Never + # "x"? - "y" -> "x"? + # "x"? - "y"? -> Never + self.assert_restriction(str1, str2, str1) + self.assert_restriction(str1, str2_inst, uninhabited) + self.assert_restriction(str1_inst, str2, str1_inst) + self.assert_restriction(str1_inst, str2_inst, uninhabited) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 2657ed26de64..9b8c063f3c00 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -647,8 +647,6 @@ def test_simplified_union_with_str_instance_literals(self) -> None: def test_simplified_union_with_mixed_str_literals(self) -> None: fx = self.fx - self.assert_simplified_union([fx.lit_str1, fx.lit_str1_inst], fx.lit_str1_inst) - self.assert_simplified_union( [fx.lit_str1, fx.lit_str2, fx.lit_str3_inst], UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]), @@ -657,6 +655,43 @@ def test_simplified_union_with_mixed_str_literals(self) -> None: [fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1_inst ) + def test_simplified_union_with_mixed_str_literals2(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" | str -> str + # str | "x" -> str + # "x"? | str -> str + # str | "x"? -> str + self.assert_simplified_union([str1, str_type], str_type) + self.assert_simplified_union([str_type, str1], str_type) + self.assert_simplified_union([str1_inst, str_type], str_type) + self.assert_simplified_union([str_type, str1_inst], str_type) + + # other operand is the same literal + # "x" | "x" -> "x" + # "x" | "x"? -> "x"? + # "x"? | "x" -> "x"? + # "x"? | "x"? -> "x"? + self.assert_simplified_union([str1, str1], str1) + self.assert_simplified_union([str1, str1_inst], str1_inst) + self.assert_simplified_union([str1_inst, str1], str1_inst) + self.assert_simplified_union([str1_inst, str1_inst], str1_inst) + + # other operand is a different literal + # "x" | "y" -> "x" | "y" + # "x" | "y"? -> "x" | "y"? + # "x"? | "y" -> "x"? | "y" + # "x"? | "y"? -> "x"? | "y"? + self.assert_simplified_union([str1, str2], UnionType([str1, str2])) + self.assert_simplified_union([str1, str2_inst], UnionType([str1, str2_inst])) + self.assert_simplified_union([str1_inst, str2], UnionType([str1_inst, str2])) + self.assert_simplified_union([str1_inst, str2_inst], UnionType([str1_inst, str2_inst])) + def assert_simplified_union(self, original: list[Type], union: Type) -> None: assert_equal(make_simplified_union(original), union) assert_equal(make_simplified_union(list(reversed(original))), union) @@ -992,7 +1027,8 @@ def test_literal_type(self) -> None: self.assert_join(lit1, lit1, lit1) self.assert_join(lit1, a, a) self.assert_join(lit1, d, self.fx.o) - self.assert_join(lit1, lit2, a) + self.assert_simple_join(lit1, lit2, UnionType([lit1, lit2])) + self.assert_simple_join(lit2, lit1, UnionType([lit2, lit1])) self.assert_join(lit1, lit3, self.fx.o) self.assert_join(lit1, self.fx.anyt, self.fx.anyt) self.assert_join(UnionType([lit1, lit2]), lit2, UnionType([lit1, lit2])) @@ -1015,6 +1051,40 @@ def test_literal_type(self) -> None: UnionType([lit2, lit3]), UnionType([lit1, lit2]), UnionType([lit2, lit3, lit1]) ) + def test_mixed_literal_types(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" , str -> str + # str , "x" -> str + # "x"?, str -> str + # str , "x"? -> str + self.assert_join(str1, str_type, str_type) + self.assert_join(str1_inst, str_type, str_type) + + # other operand is the same literal + # "x" , "x" -> "x" + # "x" , "x"? -> "x" + # "x"?, "x" -> "x" + # "x"?, "x"? -> "x"? + self.assert_join(str1, str1, str1) + self.assert_join(str1, str1_inst, str1) + self.assert_join(str1_inst, str1_inst, str1_inst) + + # other operand is a different literal + # "x" , "y" -> "x" | "y" (treat real literals like enum) + # "x" , "y"? -> str + # "x"?, "y" -> str + # "x"?, "y"? -> str + self.assert_simple_join(str1, str2, UnionType([str1, str2])) + self.assert_simple_join(str2, str1, UnionType([str2, str1])) + self.assert_join(str1, str2_inst, str_type) + self.assert_join(str1_inst, str2_inst, str_type) + def test_variadic_tuple_joins(self) -> None: # These tests really test just the "arity", to be sure it is handled correctly. self.assert_join( @@ -1308,6 +1378,39 @@ def test_literal_type(self) -> None: assert is_same_type(lit1, narrow_declared_type(lit1, a)) assert is_same_type(lit2, narrow_declared_type(lit2, a)) + def test_mixed_literal_types(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" , str -> "x" + # str , "x" -> "x" + # "x"?, str -> "x"? + # str , "x"? -> "x"? + self.assert_meet(str1, str_type, str1) + self.assert_meet(str1_inst, str_type, str1_inst) + + # other operand is the same literal + # "x" , "x" -> "x" + # "x" , "x"? -> "x" + # "x"?, "x" -> "x" + # "x"?, "x"? -> "x"? + self.assert_meet(str1, str1, str1) + self.assert_meet(str1, str1_inst, str1) + self.assert_meet(str1_inst, str1_inst, str1_inst) + + # other operand is a different literal + # "x" , "y" -> Never + # "x" , "y"? -> Never + # "x"?, "y" -> Never + # "x"?, "y"? -> str + self.assert_meet_uninhabited(str1, str2) + self.assert_meet_uninhabited(str1, str2_inst) + self.assert_meet(str1_inst, str2_inst, str_type) + # FIX generic interfaces + ranges def assert_meet_uninhabited(self, s: Type, t: Type) -> None: diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 3c9290b8dbbb..7092ea68dc4a 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -1169,7 +1169,7 @@ arr4 = [lit1, lit2, lit3] arr5 = [object(), lit1] reveal_type(arr1) # N: Revealed type is "builtins.list[Literal[1]]" -reveal_type(arr2) # N: Revealed type is "builtins.list[builtins.int]" +reveal_type(arr2) # N: Revealed type is "builtins.list[Union[Literal[1], Literal[2]]]" reveal_type(arr3) # N: Revealed type is "builtins.list[builtins.int]" reveal_type(arr4) # N: Revealed type is "builtins.list[builtins.object]" reveal_type(arr5) # N: Revealed type is "builtins.list[builtins.object]" diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index b14527963bd6..01c9525b0f35 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -2627,7 +2627,7 @@ def int_literal() -> None: case other: other # E: Statement is unreachable -def str_literal() -> None: +def str_literal_from_literal() -> None: match 'foo': case 'a' as s: reveal_type(s) # N: Revealed type is "Literal['a']" @@ -2636,6 +2636,16 @@ def str_literal() -> None: case other: other # E: Statement is unreachable + +def str_literal_from_str(arg: str) -> None: + match arg: + case 'a' as s: + reveal_type(s) # N: Revealed type is "Literal['a']" + case str(i): + reveal_type(i) # N: Revealed type is "builtins.str" + case other: + other # E: Statement is unreachable + [case testMatchOperations] # flags: --warn-unreachable From 03726cf48cca9ca5c8806082c2742cb547bbd3af Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Fri, 8 Aug 2025 19:49:13 +0200 Subject: [PATCH 08/12] change join behavior between literal and last_known_value --- mypy/join.py | 6 +++--- mypy/solve.py | 3 ++- mypy/test/testtypes.py | 6 +++--- test-data/unit/check-literal.test | 31 +++++++++++++++++++++---------- 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/mypy/join.py b/mypy/join.py index aa0970d0ba10..d3c39ad89350 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -277,7 +277,7 @@ def visit_union_type(self, t: UnionType) -> ProperType: elif isinstance(self.s, LiteralType): # E.g. join("x", "y" | "z") -> "x" | "y" | "z" # and join(1, "y" | "z") -> object - return mypy.typeops.make_simplified_union(join_types(self.s, x) for x in t.items) + return mypy.typeops.make_simplified_union([join_types(self.s, x) for x in t.items]) else: return mypy.typeops.make_simplified_union([self.s, t]) @@ -637,8 +637,8 @@ def visit_literal_type(self, t: LiteralType) -> ProperType: return mypy.typeops.make_simplified_union([self.s, t]) return join_types(self.s.fallback, t.fallback) elif isinstance(self.s, Instance) and self.s.last_known_value == t: - # E.g. Literal["x"], Literal["x"]? -> Literal["x"] - return t + # E.g. Literal["x"], Literal["x"]? -> Literal["x"]? + return self.s else: # E.g. Literal["x"], Literal["y"]? -> str return join_types(self.s, t.fallback) diff --git a/mypy/solve.py b/mypy/solve.py index fbbcac2520ad..19b7255d0307 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -319,7 +319,8 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: elif top is None: candidate = bottom elif is_subtype(bottom, top): - candidate = bottom + # Need to meet in case like Literal["x"]? <: T <: Literal["x"] + candidate = meet_types(bottom, top) else: candidate = None return candidate diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 9b8c063f3c00..2f5309fe680e 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -1068,11 +1068,11 @@ def test_mixed_literal_types(self) -> None: # other operand is the same literal # "x" , "x" -> "x" - # "x" , "x"? -> "x" - # "x"?, "x" -> "x" + # "x" , "x"? -> "x"? + # "x"?, "x" -> "x"? # "x"?, "x"? -> "x"? self.assert_join(str1, str1, str1) - self.assert_join(str1, str1_inst, str1) + self.assert_join(str1, str1_inst, str1_inst) self.assert_join(str1_inst, str1_inst, str1_inst) # other operand is a different literal diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 7092ea68dc4a..03622f0e7b47 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2980,18 +2980,29 @@ z: Type[Literal[1, 2]] # E: Type[...] can't contain "Union[Literal[...], Litera [case testJoinLiteralAndInstance] from typing import Generic, TypeVar, Literal -T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T", covariant=False, contravariant=False) +S = TypeVar("S", covariant=False, contravariant=False) -class A(Generic[T]): ... +class A_inv(Generic[T]): ... +class A_co(Generic[T_co]): ... -def f(a: A[T], t: T) -> T: ... -def g(a: T, t: A[T]) -> T: ... +def check_inv(obj: A_inv[Literal[1]]) -> None: + def f(a: A_inv[S], t: S) -> S: ... + def g(a: S, t: A_inv[S]) -> S: ... -def check(obj: A[Literal[1]]) -> None: reveal_type(f(obj, 1)) # N: Revealed type is "Literal[1]" - reveal_type(f(obj, '')) # E: Cannot infer value of type parameter "T" of "f" \ - # N: Revealed type is "Any" + reveal_type(f(obj, '')) # E: Cannot infer value of type parameter "S" of "f" \ + # N: Revealed type is "Any" reveal_type(g(1, obj)) # N: Revealed type is "Literal[1]" - reveal_type(g('', obj)) # E: Cannot infer value of type parameter "T" of "g" \ - # N: Revealed type is "Any" -[builtins fixtures/tuple.pyi] + reveal_type(g('', obj)) # E: Cannot infer value of type parameter "S" of "g" \ + # N: Revealed type is "Any" + +def check_co(obj: A_co[Literal[1]]) -> None: + def f(a: A_co[S], t: S) -> S: ... + def g(a: S, t: A_co[S]) -> S: ... + + reveal_type(f(obj, 1)) # N: Revealed type is "builtins.int" + reveal_type(f(obj, '')) # N: Revealed type is "builtins.object" + reveal_type(g(1, obj)) # N: Revealed type is "builtins.int" + reveal_type(g('', obj)) # N: Revealed type is "builtins.object" From 0ab89c5f41822d94778d30dd60e683551bd91cff Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Sat, 9 Aug 2025 09:38:19 +0200 Subject: [PATCH 09/12] make proper subtyping require last_known_value subtyping --- mypy/subtypes.py | 7 +++++++ mypy/test/testsubtypes.py | 22 +++++++++++----------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 0065133402b3..497689b740bb 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -549,6 +549,13 @@ def visit_instance(self, left: Instance) -> bool: assert isinstance(erased, Instance) t = erased nominal = True + if self.proper_subtype and right.last_known_value is not None: + if left.last_known_value is None: + # E.g. str is not a proper subtype of Literal["x"]? + nominal = False + else: + # E.g. Literal[A]? <: Literal[B]? requires A <: B + nominal &= self._is_subtype(left.last_known_value, right.last_known_value) if right.type.has_type_var_tuple_type: # For variadic instances we simply find the correct type argument mappings, # all the heavy lifting is done by the tuple subtyping. diff --git a/mypy/test/testsubtypes.py b/mypy/test/testsubtypes.py index 9c191ab3cfd8..5be32f628de1 100644 --- a/mypy/test/testsubtypes.py +++ b/mypy/test/testsubtypes.py @@ -304,7 +304,7 @@ def test_literal(self) -> None: self.assert_subtype(str1_inst, str1) self.assert_subtype(str1_inst, str1_inst) - # second operand is a different literal + # other operand is a different literal # "x" ≲ "y" -> NO # "x" ≲ "y"? -> YES # "x"? ≲ "y" -> NO @@ -315,17 +315,17 @@ def test_literal(self) -> None: self.assert_subtype(str1_inst, str2_inst) # check proper subtyping - # second operand is the fallback type + # other operand is the fallback type # "x" <: str -> YES # str <: "x" -> NO # "x"? <: str -> YES - # str <: "x"? -> YES + # str <: "x"? -> NO self.assert_proper_subtype(str1, str_type) self.assert_not_proper_subtype(str_type, str1) self.assert_proper_subtype(str1_inst, str_type) - self.assert_proper_subtype(str_type, str1_inst) + self.assert_not_proper_subtype(str_type, str1_inst) - # second operand is the same literal + # other operand is the same literal # "x" <: "x" -> YES # "x" <: "x"? -> YES # "x"? <: "x" -> NO @@ -335,15 +335,15 @@ def test_literal(self) -> None: self.assert_not_proper_subtype(str1_inst, str1) self.assert_proper_subtype(str1_inst, str1_inst) - # second operand is a different literal - # "x" ≲ "y" -> NO - # "x" ≲ "y"? -> NO - # "x"? ≲ "y" -> NO - # "x"? ≲ "y"? -> YES + # other operand is a different literal + # "x" <: "y" -> NO + # "x" <: "y"? -> NO + # "x"? <: "y" -> NO + # "x"? <: "y"? -> NO self.assert_not_proper_subtype(str1, str2) self.assert_not_proper_subtype(str1, str2_inst) self.assert_not_proper_subtype(str1_inst, str2) - self.assert_proper_subtype(str1_inst, str2_inst) + self.assert_not_proper_subtype(str1_inst, str2_inst) # IDEA: Maybe add these test cases (they are tested pretty well in type # checker tests already): From 9da85c82e02e678850d28972eee581bf94d7705e Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Sat, 9 Aug 2025 09:43:18 +0200 Subject: [PATCH 10/12] revert union behavior --- mypy/join.py | 8 +------- mypy/test/testtypes.py | 8 +++----- test-data/unit/check-literal.test | 2 +- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/mypy/join.py b/mypy/join.py index d3c39ad89350..30c77bbebf7d 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -274,10 +274,6 @@ def visit_unbound_type(self, t: UnboundType) -> ProperType: def visit_union_type(self, t: UnionType) -> ProperType: if is_proper_subtype(self.s, t): return t - elif isinstance(self.s, LiteralType): - # E.g. join("x", "y" | "z") -> "x" | "y" | "z" - # and join(1, "y" | "z") -> object - return mypy.typeops.make_simplified_union([join_types(self.s, x) for x in t.items]) else: return mypy.typeops.make_simplified_union([self.s, t]) @@ -631,9 +627,7 @@ def visit_literal_type(self, t: LiteralType) -> ProperType: if t == self.s: # E.g. Literal["x"], Literal["x"] -> Literal["x"] return t - if (self.s.fallback.type == t.fallback.type) or ( - self.s.fallback.type.is_enum and t.fallback.type.is_enum - ): + if self.s.fallback.type.is_enum and t.fallback.type.is_enum: return mypy.typeops.make_simplified_union([self.s, t]) return join_types(self.s.fallback, t.fallback) elif isinstance(self.s, Instance) and self.s.last_known_value == t: diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 2f5309fe680e..bac0da779756 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -1027,8 +1027,7 @@ def test_literal_type(self) -> None: self.assert_join(lit1, lit1, lit1) self.assert_join(lit1, a, a) self.assert_join(lit1, d, self.fx.o) - self.assert_simple_join(lit1, lit2, UnionType([lit1, lit2])) - self.assert_simple_join(lit2, lit1, UnionType([lit2, lit1])) + self.assert_join(lit1, lit2, a) self.assert_join(lit1, lit3, self.fx.o) self.assert_join(lit1, self.fx.anyt, self.fx.anyt) self.assert_join(UnionType([lit1, lit2]), lit2, UnionType([lit1, lit2])) @@ -1076,12 +1075,11 @@ def test_mixed_literal_types(self) -> None: self.assert_join(str1_inst, str1_inst, str1_inst) # other operand is a different literal - # "x" , "y" -> "x" | "y" (treat real literals like enum) + # "x" , "y" -> str (TODO: consider using "x" | "y" (treat real literals like enum)) # "x" , "y"? -> str # "x"?, "y" -> str # "x"?, "y"? -> str - self.assert_simple_join(str1, str2, UnionType([str1, str2])) - self.assert_simple_join(str2, str1, UnionType([str2, str1])) + self.assert_join(str1, str2, str_type) self.assert_join(str1, str2_inst, str_type) self.assert_join(str1_inst, str2_inst, str_type) diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 03622f0e7b47..89b4a9d809a2 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -1169,7 +1169,7 @@ arr4 = [lit1, lit2, lit3] arr5 = [object(), lit1] reveal_type(arr1) # N: Revealed type is "builtins.list[Literal[1]]" -reveal_type(arr2) # N: Revealed type is "builtins.list[Union[Literal[1], Literal[2]]]" +reveal_type(arr2) # N: Revealed type is "builtins.list[builtins.int]" reveal_type(arr3) # N: Revealed type is "builtins.list[builtins.int]" reveal_type(arr4) # N: Revealed type is "builtins.list[builtins.object]" reveal_type(arr5) # N: Revealed type is "builtins.list[builtins.object]" From cb94785bec81ba12f338762a0b55cf0d406c03ba Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Sat, 9 Aug 2025 09:48:48 +0200 Subject: [PATCH 11/12] remove testCastFromLiteralRedundant --- test-data/unit/check-warnings.test | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test-data/unit/check-warnings.test b/test-data/unit/check-warnings.test index a2d201fa301d..acc122d8fb89 100644 --- a/test-data/unit/check-warnings.test +++ b/test-data/unit/check-warnings.test @@ -49,13 +49,6 @@ from typing import cast a = 1 b = cast(object, 1) -[case testCastFromLiteralRedundant] -# flags: --warn-redundant-casts -from typing import cast - -cast(int, 1) -[out] -main:4: error: Redundant cast to "int" [case testCastFromUnionOfAnyOk] # flags: --warn-redundant-casts From 2059a1a6f5cc9b558be05b8c104e3fa83baaec0d Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Sat, 9 Aug 2025 12:39:32 +0200 Subject: [PATCH 12/12] fix enum+literal getting coerced to Sequence[str] --- mypy/join.py | 6 ++++++ test-data/unit/check-literal.test | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/mypy/join.py b/mypy/join.py index 30c77bbebf7d..49ddbc35373c 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -150,6 +150,12 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: t, s, subtype_context=SubtypeContext(ignore_type_params=True) ): result = self.join_instances_via_supertype(t, s) + elif s.type.bases and is_proper_subtype( + s, t, subtype_context=SubtypeContext(ignore_type_params=True) + ): + result = self.join_instances_via_supertype(s, t) + elif is_subtype(t, s, subtype_context=SubtypeContext(ignore_type_params=True)): + result = self.join_instances_via_supertype(t, s) else: # Now t is not a subtype of s, and t != s. Now s could be a subtype # of t; alternatively, we need to find a common supertype. This works diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 89b4a9d809a2..2e452031c092 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -3006,3 +3006,20 @@ def check_co(obj: A_co[Literal[1]]) -> None: reveal_type(f(obj, '')) # N: Revealed type is "builtins.object" reveal_type(g(1, obj)) # N: Revealed type is "builtins.int" reveal_type(g('', obj)) # N: Revealed type is "builtins.object" + +[case testJoinLiteralInstanceAndEnum] +from typing import Final, TypeVar +from enum import StrEnum + +T = TypeVar("T") +def join(a: T, b: T) -> T: ... + +class Foo(StrEnum): + A = "a" + +CONST: Final = "const" + +reveal_type(CONST) # N: Revealed type is "Literal['const']?" +reveal_type(join(Foo.A, CONST)) # N: Revealed type is "builtins.str" +reveal_type(join(CONST, Foo.A)) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi]