Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def is_undefined(self, name: str) -> bool:
class Loop:
def __init__(self) -> None:
self.has_break = False
# variables defined in every loop branch with `break`
self.break_vars: set[str] | None = None


class PossiblyUndefinedVariableVisitor(ExtendedTraverserVisitor):
Expand Down Expand Up @@ -472,6 +474,9 @@ def visit_for_stmt(self, o: ForStmt) -> None:
has_break = loop.has_break
if has_break:
self.tracker.start_branch_statement()
if loop.break_vars is not None:
for bv in loop.break_vars:
self.tracker.record_definition(bv)
self.tracker.next_branch()
o.else_body.accept(self)
if has_break:
Expand Down Expand Up @@ -504,6 +509,14 @@ def visit_break_stmt(self, o: BreakStmt) -> None:
super().visit_break_stmt(o)
if self.loops:
self.loops[-1].has_break = True
# Track variables that are definitely defined at the point of break
if len(self.tracker._scope().branch_stmts) > 0:
branch = self.tracker._scope().branch_stmts[-1].branches[-1]
if self.loops[-1].break_vars is None:
self.loops[-1].break_vars = set(branch.must_be_defined)
else:
# we only want variables that have been defined in each branch
self.loops[-1].break_vars.intersection_update(branch.must_be_defined)
self.tracker.skip_branch()

def visit_expression_stmt(self, o: ExpressionStmt) -> None:
Expand Down
123 changes: 123 additions & 0 deletions test-data/unit/check-possibly-undefined.test
Original file line number Diff line number Diff line change
Expand Up @@ -1043,3 +1043,126 @@ def foo(x: Union[int, str]) -> None:
assert_never(x)
f # OK
[builtins fixtures/tuple.pyi]

[case testForElseWithBreakInTryExceptContinue]
# flags: --enable-error-code possibly-undefined
# Test for issue where variable defined before break in try block
# was incorrectly reported as undefined when except has continue
def random() -> float: return 0.5

if random():
for i in range(10):
try:
value = random()
break
except Exception:
continue
else:
raise RuntimeError

print(value) # Should not error - value is defined if we broke
else:
value = random()

print(value) # Should not error
[builtins fixtures/for_else_exception.pyi]
[typing fixtures/typing-medium.pyi]

[case testForElseWithBreakInTryExceptContinueNoIf]
# flags: --enable-error-code possibly-undefined
# Simpler version without if statement
def random() -> float: return 0.5

for i in range(10):
try:
value = random()
break
except Exception:
continue
else:
raise RuntimeError

print(value) # Should not error
[builtins fixtures/for_else_exception.pyi]
[typing fixtures/typing-medium.pyi]

[case testForElseWithBreakInTryExceptPass]
# flags: --enable-error-code possibly-undefined
# Version with pass instead of continue - should also work
def random() -> float: return 0.5

if random():
for i in range(10):
try:
value = random()
break
except Exception:
pass
else:
raise RuntimeError

print(value) # Should not error
else:
value = random()

print(value) # Should not error
[builtins fixtures/for_else_exception.pyi]
[typing fixtures/typing-medium.pyi]

[case testForElseWithConditionalDefBeforeBreak]
# flags: --enable-error-code possibly-undefined
# Test that conditional definition before break still works correctly
def random() -> float: return 0.5

if random():
for i in range(10):
try:
if i > 10:
value = random()
break
except Exception:
continue
else:
raise RuntimeError

print(value) # Should not error (though might be undefined at runtime)
else:
value = random()

print(value) # Should not error
[builtins fixtures/for_else_exception.pyi]
[typing fixtures/typing-medium.pyi]

[case testForElseDefineInBothBranches]
# flags: --enable-error-code possibly-undefined
# Test that variable defined in both for break and else branches is not undefined
for i in range(10):
if i:
value = i
break
else:
value = 1

print(value) # Should not error - value is defined in all paths
[builtins fixtures/for_else_exception.pyi]
[typing fixtures/typing-medium.pyi]

[case testForElseWithWalrusInBreak]
# flags: --enable-error-code possibly-undefined
# Test with walrus operator in if condition before break
def random() -> float: return 0.5

if random():
for i in range(10):
if value := random():
break
else:
raise RuntimeError

print(value) # Should not error - value is defined if we broke
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also add one two tests where it should error? Like, e.g, if raise above is conditional.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I will add the test.

else:
value = random()

print(value) # Should not error
[builtins fixtures/for_else_exception.pyi]
[typing fixtures/typing-medium.pyi]
54 changes: 54 additions & 0 deletions test-data/unit/fixtures/for_else_exception.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Fixture for for-else tests with exceptions
# Combines needed elements from primitives.pyi and exception.pyi

from typing import Generic, Iterator, Mapping, Sequence, TypeVar

T = TypeVar('T')
V = TypeVar('V')

class object:
def __init__(self) -> None: pass
class type:
def __init__(self, x: object) -> None: pass
class int:
def __init__(self, x: object = ..., base: int = ...) -> None: pass
def __add__(self, i: int) -> int: pass
def __rmul__(self, x: int) -> int: pass
def __bool__(self) -> bool: pass
def __eq__(self, x: object) -> bool: pass
def __ne__(self, x: object) -> bool: pass
def __lt__(self, x: 'int') -> bool: pass
def __le__(self, x: 'int') -> bool: pass
def __gt__(self, x: 'int') -> bool: pass
def __ge__(self, x: 'int') -> bool: pass
class float:
def __float__(self) -> float: pass
def __add__(self, x: float) -> float: pass
def hex(self) -> str: pass
class bool(int): pass
class str(Sequence[str]):
def __add__(self, s: str) -> str: pass
def __iter__(self) -> Iterator[str]: pass
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> str: pass
def format(self, *args: object, **kwargs: object) -> str: pass
class dict(Mapping[T, V]):
def __iter__(self) -> Iterator[T]: pass
class tuple(Generic[T]):
def __contains__(self, other: object) -> bool: pass
class ellipsis: pass

class BaseException:
def __init__(self, *args: object) -> None: ...
class Exception(BaseException): pass
class RuntimeError(Exception): pass

class range(Sequence[int]):
def __init__(self, __x: int, __y: int = ..., __z: int = ...) -> None: pass
def count(self, value: int) -> int: pass
def index(self, value: int) -> int: pass
def __getitem__(self, i: int) -> int: pass
def __iter__(self) -> Iterator[int]: pass
def __contains__(self, other: object) -> bool: pass

def print(x: object) -> None: pass
Loading