Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,10 @@ def can_coerce_to(src: RType, dest: RType) -> bool:
if isinstance(src, RPrimitive):
# If either src or dest is a disjoint type, then they must both be.
if src.name in disjoint_types and dest.name in disjoint_types:
return src.name == dest.name
return src.name == dest.name or (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there must be a better way to do this

src.name in ("builtins.dict", "builtins.dict[exact]")
and dest.name in ("builtins.dict", "builtins.dict[exact]")
)
return src.size == dest.size
if isinstance(src, RInstance):
return is_object_rprimitive(dest)
Expand Down
16 changes: 14 additions & 2 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,12 @@ def __hash__(self) -> int:
"builtins.list", is_unboxed=False, is_refcounted=True, may_be_immortal=False
)

# Python dict object (or an instance of a subclass of dict).
# Python dict object.
exact_dict_rprimitive: Final = RPrimitive(
"builtins.dict[exact]", is_unboxed=False, is_refcounted=True
)

# An instance of a subclass of dict.
dict_rprimitive: Final = RPrimitive("builtins.dict", is_unboxed=False, is_refcounted=True)

# Python set object (or an instance of a subclass of set).
Expand Down Expand Up @@ -608,7 +613,14 @@ def is_list_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:


def is_dict_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict"
return isinstance(rtype, RPrimitive) and rtype.name in (
"builtins.dict",
"builtins.dict[exact]",
)


def is_exact_dict_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict[exact]"


def is_set_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
Expand Down
3 changes: 2 additions & 1 deletion mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
c_pyssize_t_rprimitive,
cstring_rprimitive,
dict_rprimitive,
exact_dict_rprimitive,
float_rprimitive,
int_rprimitive,
none_rprimitive,
Expand Down Expand Up @@ -160,7 +161,7 @@
# Get the sys.modules dictionary
get_module_dict_op = custom_op(
arg_types=[],
return_type=dict_rprimitive,
return_type=exact_dict_rprimitive,
c_function_name="PyImport_GetModuleDict",
error_kind=ERR_NEVER,
is_borrowed=True,
Expand Down
4 changes: 4 additions & 0 deletions mypyc/rt_subtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
RVoid,
is_bit_rprimitive,
is_bool_rprimitive,
is_dict_rprimitive,
is_exact_dict_rprimitive,
is_int_rprimitive,
is_short_int_rprimitive,
)
Expand Down Expand Up @@ -58,6 +60,8 @@ def visit_rprimitive(self, left: RPrimitive) -> bool:
return True
if is_bit_rprimitive(left) and is_bool_rprimitive(self.right):
return True
if is_exact_dict_rprimitive(left) and is_dict_rprimitive(self.right):
return True
return left is self.right

def visit_rtuple(self, left: RTuple) -> bool:
Expand Down
5 changes: 5 additions & 0 deletions mypyc/subtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
RVoid,
is_bit_rprimitive,
is_bool_rprimitive,
is_dict_rprimitive,
is_exact_dict_rprimitive,
is_fixed_width_rtype,
is_int_rprimitive,
is_object_rprimitive,
Expand Down Expand Up @@ -67,6 +69,9 @@ def visit_rprimitive(self, left: RPrimitive) -> bool:
elif is_fixed_width_rtype(left):
if is_int_rprimitive(right):
return True
elif is_exact_dict_rprimitive(left):
if is_dict_rprimitive(right):
return True
return left is right

def visit_rtuple(self, left: RTuple) -> bool:
Expand Down
8 changes: 4 additions & 4 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -3270,7 +3270,7 @@ def root():
r4 :: str
r5 :: object
r6 :: str
r7 :: dict
r7 :: dict[exact]
r8 :: str
r9 :: object
r10 :: i32
Expand All @@ -3281,7 +3281,7 @@ def root():
r16 :: str
r17 :: object
r18 :: str
r19 :: dict
r19 :: dict[exact]
r20 :: str
r21 :: object
r22 :: i32
Expand Down Expand Up @@ -3327,12 +3327,12 @@ def submodule():
r4 :: str
r5 :: object
r6 :: str
r7 :: dict
r7 :: dict[exact]
r8 :: str
r9 :: object
r10 :: i32
r11 :: bit
r12 :: dict
r12 :: dict[exact]
r13 :: str
r14 :: object
r15 :: str
Expand Down
Loading