From ce69b76650223067f3a01b134d07ff2f1a8edae3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Langa?= Date: Wed, 20 Aug 2025 13:09:09 +0200 Subject: [PATCH 1/3] Fix spurious possibly-undefined errors in for-else with break When a for loop contains branches with `break` and an `else` block, variables declared inside those branches were incorrectly discarded from further analysis, leading Mypy to incorrectly report a variable as undefined after the loop or as used before declaration. With this fix, when a for loop's `else` block is considered, variables declared in every branch of the `for` loop body that called `break` are now considered as defined within the body of the loop. Fixes #14209 Fixes #19690 --- mypy/partially_defined.py | 69 ++++++++++ test-data/unit/check-possibly-undefined.test | 123 ++++++++++++++++++ .../unit/fixtures/for_else_exception.pyi | 54 ++++++++ 3 files changed, 246 insertions(+) create mode 100644 test-data/unit/fixtures/for_else_exception.pyi diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index 38154cf697e1..03cf62b0d159 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -48,6 +48,16 @@ from mypy.types import Type, UninhabitedType, get_proper_type +def _ambv(s: str) -> None: + assert s + pass # print("DEBUG:", s) + + +def _ambv_cont(s: str) -> None: + assert s + pass # print(s) + + class BranchState: """BranchState contains information about variable definition at the end of a branching statement. `if` and `match` are examples of branching statements. @@ -117,6 +127,9 @@ def delete_var(self, name: str) -> None: def record_nested_branch(self, state: BranchState) -> None: assert len(self.branches) > 0 current_branch = self.branches[-1] + _ambv( + f"record_nested_branch: state.must={state.must_be_defined if 'value' in state.must_be_defined else '...'}, state.may={'value' if 'value' in state.may_be_defined else '...'}, state.skipped={state.skipped}" + ) if state.skipped: current_branch.skipped = True return @@ -154,6 +167,17 @@ def done(self) -> BranchState: all_vars.update(b.must_be_defined) # For the rest of the things, we only care about branches that weren't skipped. non_skipped_branches = [b for b in self.branches if not b.skipped] + import sys + + _called_by = sys._getframe(2).f_code.co_name + _ambv( + f"done {_called_by}: branches={len(self.branches)}, non_skipped={len(non_skipped_branches)}" + ) + for i, b in enumerate(self.branches): + has_value = "value" in b.must_be_defined or "value" in b.may_be_defined + _ambv_cont( + f" Branch {i}: has_value={has_value}, skipped={b.skipped}, must={b.must_be_defined}, may={b.may_be_defined}" + ) if non_skipped_branches: must_be_defined = non_skipped_branches[0].must_be_defined for b in non_skipped_branches[1:]: @@ -163,6 +187,7 @@ def done(self) -> BranchState: # Everything that wasn't defined in all branches but was defined # in at least one branch should be in `may_be_defined`! may_be_defined = all_vars.difference(must_be_defined) + _ambv_cont(f" Result: must={must_be_defined}, may={may_be_defined}") return BranchState( must_be_defined=must_be_defined, may_be_defined=may_be_defined, @@ -295,6 +320,8 @@ def is_undefined(self, name: str) -> bool: class Loop: def __init__(self) -> None: self.has_break = False + # variables defined in every loop branch with `break` + self.break_vars: set[str] | None = None class PossiblyUndefinedVariableVisitor(ExtendedTraverserVisitor): @@ -336,6 +363,10 @@ def __init__( for name in implicit_module_attrs: self.tracker.record_definition(name) + # def visit_block(self, block: Block, /) -> None: + # _ambv(f"PossiblyUndefinedVariableVisitor visiting {block}") + # super().visit_block(block) + def var_used_before_def(self, name: str, context: Context) -> None: if self.msg.errors.is_error_code_enabled(errorcodes.USED_BEFORE_DEF): self.msg.var_used_before_def(name, context) @@ -349,6 +380,9 @@ def process_definition(self, name: str) -> None: if not self.tracker.in_scope(ScopeType.Class): refs = self.tracker.pop_undefined_ref(name) for ref in refs: + _ambv( + f"process_definition for {name}, ref at line {ref.line}, loops={bool(self.loops)}" + ) if self.loops: self.variable_may_be_undefined(name, ref) else: @@ -370,6 +404,9 @@ def visit_nonlocal_decl(self, o: NonlocalDecl) -> None: def process_lvalue(self, lvalue: Lvalue | None) -> None: if isinstance(lvalue, NameExpr): + _ambv( + f"process_lvalue calling process_definition for {lvalue.name} at line {lvalue.line}" + ) self.process_definition(lvalue.name) elif isinstance(lvalue, StarExpr): self.process_lvalue(lvalue.expr) @@ -378,6 +415,7 @@ def process_lvalue(self, lvalue: Lvalue | None) -> None: self.process_lvalue(item) def visit_assignment_stmt(self, o: AssignmentStmt) -> None: + _ambv(f"visit_assignment_stmt at line {o.line}") for lvalue in o.lvalues: self.process_lvalue(lvalue) super().visit_assignment_stmt(o) @@ -456,22 +494,39 @@ def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None: self.tracker.exit_scope() def visit_for_stmt(self, o: ForStmt) -> None: + _ambv(f"visit_for_stmt: line {o.line}") o.expr.accept(self) self.process_lvalue(o.index) o.index.accept(self) self.tracker.start_branch_statement() loop = Loop() self.loops.append(loop) + _ambv( + f"visit_for_stmt: Before body, current state: {self.tracker._scope().branch_stmts[-1].branches[-1].must_be_defined}, may={self.tracker._scope().branch_stmts[-1].branches[-1].may_be_defined}" + ) o.body.accept(self) + _ambv(f"visit_for_stmt: after body, has_break={loop.has_break}") + _ambv( + f"visit_for_stmt: After body, current state: {self.tracker._scope().branch_stmts[-1].branches[-1].must_be_defined}, may={self.tracker._scope().branch_stmts[-1].branches[-1].may_be_defined}" + ) self.tracker.next_branch() + _ambv( + f"visit_for_stmt: After next_branch, new branch state: {self.tracker._scope().branch_stmts[-1].branches[-1].must_be_defined}, may={self.tracker._scope().branch_stmts[-1].branches[-1].may_be_defined}" + ) self.tracker.end_branch_statement() if o.else_body is not None: # If the loop has a `break` inside, `else` is executed conditionally. # If the loop doesn't have a `break` either the function will return or # execute the `else`. has_break = loop.has_break + _ambv( + f"visit_for_stmt: else_body present, has_break={has_break}, break_vars={loop.break_vars}" + ) if has_break: self.tracker.start_branch_statement() + if loop.break_vars is not None: + for bv in loop.break_vars: + self.tracker.record_definition(bv) self.tracker.next_branch() o.else_body.accept(self) if has_break: @@ -497,6 +552,7 @@ def visit_raise_stmt(self, o: RaiseStmt) -> None: self.tracker.skip_branch() def visit_continue_stmt(self, o: ContinueStmt) -> None: + _ambv(f"continue at line {o.line}, skipping branch") super().visit_continue_stmt(o) self.tracker.skip_branch() @@ -504,6 +560,14 @@ def visit_break_stmt(self, o: BreakStmt) -> None: super().visit_break_stmt(o) if self.loops: self.loops[-1].has_break = True + # Track variables that are definitely defined at the point of break + if len(self.tracker._scope().branch_stmts) > 0: + branch = self.tracker._scope().branch_stmts[-1].branches[-1] + if self.loops[-1].break_vars is None: + self.loops[-1].break_vars = set(branch.must_be_defined) + else: + # we only want variables that have been defined in each branch + self.loops[-1].break_vars.intersection_update(branch.must_be_defined) self.tracker.skip_branch() def visit_expression_stmt(self, o: ExpressionStmt) -> None: @@ -545,6 +609,7 @@ def f() -> int: self.try_depth -= 1 def process_try_stmt(self, o: TryStmt) -> None: + _ambv(f"process_try_stmt: line {o.line}, handlers={len(o.handlers)}") """ Processes try statement decomposing it into the following: if ...: @@ -620,6 +685,9 @@ def visit_starred_pattern(self, o: StarredPattern) -> None: def visit_name_expr(self, o: NameExpr) -> None: if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global): return + _ambv( + f"visit_name_expr {o.name} at line {o.line}, possibly_undefined={self.tracker.is_possibly_undefined(o.name)}, defined_in_different_branch={self.tracker.is_defined_in_different_branch(o.name)}, is_undefined={self.tracker.is_undefined(o.name)}" + ) if self.tracker.is_possibly_undefined(o.name): # A variable is only defined in some branches. self.variable_may_be_undefined(o.name, o) @@ -640,6 +708,7 @@ def visit_name_expr(self, o: NameExpr) -> None: # Case (1) will be caught by semantic analyzer. Case (2) is a forward ref that should # be caught by this visitor. Save the ref for later, so that if we see a definition, # we know it's a used-before-definition scenario. + _ambv(f"Recording undefined ref for {o.name} at line {o.line}") self.tracker.record_undefined_ref(o) super().visit_name_expr(o) diff --git a/test-data/unit/check-possibly-undefined.test b/test-data/unit/check-possibly-undefined.test index ae277949c049..9dd9a3d7969d 100644 --- a/test-data/unit/check-possibly-undefined.test +++ b/test-data/unit/check-possibly-undefined.test @@ -1043,3 +1043,126 @@ def foo(x: Union[int, str]) -> None: assert_never(x) f # OK [builtins fixtures/tuple.pyi] + +[case testForElseWithBreakInTryExceptContinue] +# flags: --enable-error-code possibly-undefined +# Test for issue where variable defined before break in try block +# was incorrectly reported as undefined when except has continue +def random() -> float: return 0.5 + +if random(): + for i in range(10): + try: + value = random() + break + except Exception: + continue + else: + raise RuntimeError + + print(value) # Should not error - value is defined if we broke +else: + value = random() + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseWithBreakInTryExceptContinueNoIf] +# flags: --enable-error-code possibly-undefined +# Simpler version without if statement +def random() -> float: return 0.5 + +for i in range(10): + try: + value = random() + break + except Exception: + continue +else: + raise RuntimeError + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseWithBreakInTryExceptPass] +# flags: --enable-error-code possibly-undefined +# Version with pass instead of continue - should also work +def random() -> float: return 0.5 + +if random(): + for i in range(10): + try: + value = random() + break + except Exception: + pass + else: + raise RuntimeError + + print(value) # Should not error +else: + value = random() + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseWithConditionalDefBeforeBreak] +# flags: --enable-error-code possibly-undefined +# Test that conditional definition before break still works correctly +def random() -> float: return 0.5 + +if random(): + for i in range(10): + try: + if i > 10: + value = random() + break + except Exception: + continue + else: + raise RuntimeError + + print(value) # Should not error (though might be undefined at runtime) +else: + value = random() + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseDefineInBothBranches] +# flags: --enable-error-code possibly-undefined +# Test that variable defined in both for break and else branches is not undefined +for i in range(10): + if i: + value = i + break +else: + value = 1 + +print(value) # Should not error - value is defined in all paths +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseWithWalrusInBreak] +# flags: --enable-error-code possibly-undefined +# Test with walrus operator in if condition before break +def random() -> float: return 0.5 + +if random(): + for i in range(10): + if value := random(): + break + else: + raise RuntimeError + + print(value) # Should not error - value is defined if we broke +else: + value = random() + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] diff --git a/test-data/unit/fixtures/for_else_exception.pyi b/test-data/unit/fixtures/for_else_exception.pyi new file mode 100644 index 000000000000..98c953caff72 --- /dev/null +++ b/test-data/unit/fixtures/for_else_exception.pyi @@ -0,0 +1,54 @@ +# Fixture for for-else tests with exceptions +# Combines needed elements from primitives.pyi and exception.pyi + +from typing import Generic, Iterator, Mapping, Sequence, TypeVar + +T = TypeVar('T') +V = TypeVar('V') + +class object: + def __init__(self) -> None: pass +class type: + def __init__(self, x: object) -> None: pass +class int: + def __init__(self, x: object = ..., base: int = ...) -> None: pass + def __add__(self, i: int) -> int: pass + def __rmul__(self, x: int) -> int: pass + def __bool__(self) -> bool: pass + def __eq__(self, x: object) -> bool: pass + def __ne__(self, x: object) -> bool: pass + def __lt__(self, x: 'int') -> bool: pass + def __le__(self, x: 'int') -> bool: pass + def __gt__(self, x: 'int') -> bool: pass + def __ge__(self, x: 'int') -> bool: pass +class float: + def __float__(self) -> float: pass + def __add__(self, x: float) -> float: pass + def hex(self) -> str: pass +class bool(int): pass +class str(Sequence[str]): + def __add__(self, s: str) -> str: pass + def __iter__(self) -> Iterator[str]: pass + def __contains__(self, other: object) -> bool: pass + def __getitem__(self, item: int) -> str: pass + def format(self, *args: object, **kwargs: object) -> str: pass +class dict(Mapping[T, V]): + def __iter__(self) -> Iterator[T]: pass +class tuple(Generic[T]): + def __contains__(self, other: object) -> bool: pass +class ellipsis: pass + +class BaseException: + def __init__(self, *args: object) -> None: ... +class Exception(BaseException): pass +class RuntimeError(Exception): pass + +class range(Sequence[int]): + def __init__(self, __x: int, __y: int = ..., __z: int = ...) -> None: pass + def count(self, value: int) -> int: pass + def index(self, value: int) -> int: pass + def __getitem__(self, i: int) -> int: pass + def __iter__(self) -> Iterator[int]: pass + def __contains__(self, other: object) -> bool: pass + +def print(x: object) -> None: pass From 1442d7782fe140cf8c5352ddf1f1c3be0a8fe024 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:10:04 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test-data/unit/check-possibly-undefined.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/check-possibly-undefined.test b/test-data/unit/check-possibly-undefined.test index 9dd9a3d7969d..3bf52e5b8847 100644 --- a/test-data/unit/check-possibly-undefined.test +++ b/test-data/unit/check-possibly-undefined.test @@ -1158,7 +1158,7 @@ if random(): break else: raise RuntimeError - + print(value) # Should not error - value is defined if we broke else: value = random() From dc725bf6ecd2e2293c75e5d490c57a0af909a65f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Langa?= Date: Wed, 20 Aug 2025 17:24:08 +0200 Subject: [PATCH 3/3] Remove debug prints --- mypy/partially_defined.py | 56 --------------------------------------- 1 file changed, 56 deletions(-) diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index 03cf62b0d159..ada28ec4bd3c 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -48,16 +48,6 @@ from mypy.types import Type, UninhabitedType, get_proper_type -def _ambv(s: str) -> None: - assert s - pass # print("DEBUG:", s) - - -def _ambv_cont(s: str) -> None: - assert s - pass # print(s) - - class BranchState: """BranchState contains information about variable definition at the end of a branching statement. `if` and `match` are examples of branching statements. @@ -127,9 +117,6 @@ def delete_var(self, name: str) -> None: def record_nested_branch(self, state: BranchState) -> None: assert len(self.branches) > 0 current_branch = self.branches[-1] - _ambv( - f"record_nested_branch: state.must={state.must_be_defined if 'value' in state.must_be_defined else '...'}, state.may={'value' if 'value' in state.may_be_defined else '...'}, state.skipped={state.skipped}" - ) if state.skipped: current_branch.skipped = True return @@ -167,17 +154,6 @@ def done(self) -> BranchState: all_vars.update(b.must_be_defined) # For the rest of the things, we only care about branches that weren't skipped. non_skipped_branches = [b for b in self.branches if not b.skipped] - import sys - - _called_by = sys._getframe(2).f_code.co_name - _ambv( - f"done {_called_by}: branches={len(self.branches)}, non_skipped={len(non_skipped_branches)}" - ) - for i, b in enumerate(self.branches): - has_value = "value" in b.must_be_defined or "value" in b.may_be_defined - _ambv_cont( - f" Branch {i}: has_value={has_value}, skipped={b.skipped}, must={b.must_be_defined}, may={b.may_be_defined}" - ) if non_skipped_branches: must_be_defined = non_skipped_branches[0].must_be_defined for b in non_skipped_branches[1:]: @@ -187,7 +163,6 @@ def done(self) -> BranchState: # Everything that wasn't defined in all branches but was defined # in at least one branch should be in `may_be_defined`! may_be_defined = all_vars.difference(must_be_defined) - _ambv_cont(f" Result: must={must_be_defined}, may={may_be_defined}") return BranchState( must_be_defined=must_be_defined, may_be_defined=may_be_defined, @@ -363,10 +338,6 @@ def __init__( for name in implicit_module_attrs: self.tracker.record_definition(name) - # def visit_block(self, block: Block, /) -> None: - # _ambv(f"PossiblyUndefinedVariableVisitor visiting {block}") - # super().visit_block(block) - def var_used_before_def(self, name: str, context: Context) -> None: if self.msg.errors.is_error_code_enabled(errorcodes.USED_BEFORE_DEF): self.msg.var_used_before_def(name, context) @@ -380,9 +351,6 @@ def process_definition(self, name: str) -> None: if not self.tracker.in_scope(ScopeType.Class): refs = self.tracker.pop_undefined_ref(name) for ref in refs: - _ambv( - f"process_definition for {name}, ref at line {ref.line}, loops={bool(self.loops)}" - ) if self.loops: self.variable_may_be_undefined(name, ref) else: @@ -404,9 +372,6 @@ def visit_nonlocal_decl(self, o: NonlocalDecl) -> None: def process_lvalue(self, lvalue: Lvalue | None) -> None: if isinstance(lvalue, NameExpr): - _ambv( - f"process_lvalue calling process_definition for {lvalue.name} at line {lvalue.line}" - ) self.process_definition(lvalue.name) elif isinstance(lvalue, StarExpr): self.process_lvalue(lvalue.expr) @@ -415,7 +380,6 @@ def process_lvalue(self, lvalue: Lvalue | None) -> None: self.process_lvalue(item) def visit_assignment_stmt(self, o: AssignmentStmt) -> None: - _ambv(f"visit_assignment_stmt at line {o.line}") for lvalue in o.lvalues: self.process_lvalue(lvalue) super().visit_assignment_stmt(o) @@ -494,34 +458,20 @@ def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None: self.tracker.exit_scope() def visit_for_stmt(self, o: ForStmt) -> None: - _ambv(f"visit_for_stmt: line {o.line}") o.expr.accept(self) self.process_lvalue(o.index) o.index.accept(self) self.tracker.start_branch_statement() loop = Loop() self.loops.append(loop) - _ambv( - f"visit_for_stmt: Before body, current state: {self.tracker._scope().branch_stmts[-1].branches[-1].must_be_defined}, may={self.tracker._scope().branch_stmts[-1].branches[-1].may_be_defined}" - ) o.body.accept(self) - _ambv(f"visit_for_stmt: after body, has_break={loop.has_break}") - _ambv( - f"visit_for_stmt: After body, current state: {self.tracker._scope().branch_stmts[-1].branches[-1].must_be_defined}, may={self.tracker._scope().branch_stmts[-1].branches[-1].may_be_defined}" - ) self.tracker.next_branch() - _ambv( - f"visit_for_stmt: After next_branch, new branch state: {self.tracker._scope().branch_stmts[-1].branches[-1].must_be_defined}, may={self.tracker._scope().branch_stmts[-1].branches[-1].may_be_defined}" - ) self.tracker.end_branch_statement() if o.else_body is not None: # If the loop has a `break` inside, `else` is executed conditionally. # If the loop doesn't have a `break` either the function will return or # execute the `else`. has_break = loop.has_break - _ambv( - f"visit_for_stmt: else_body present, has_break={has_break}, break_vars={loop.break_vars}" - ) if has_break: self.tracker.start_branch_statement() if loop.break_vars is not None: @@ -552,7 +502,6 @@ def visit_raise_stmt(self, o: RaiseStmt) -> None: self.tracker.skip_branch() def visit_continue_stmt(self, o: ContinueStmt) -> None: - _ambv(f"continue at line {o.line}, skipping branch") super().visit_continue_stmt(o) self.tracker.skip_branch() @@ -609,7 +558,6 @@ def f() -> int: self.try_depth -= 1 def process_try_stmt(self, o: TryStmt) -> None: - _ambv(f"process_try_stmt: line {o.line}, handlers={len(o.handlers)}") """ Processes try statement decomposing it into the following: if ...: @@ -685,9 +633,6 @@ def visit_starred_pattern(self, o: StarredPattern) -> None: def visit_name_expr(self, o: NameExpr) -> None: if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global): return - _ambv( - f"visit_name_expr {o.name} at line {o.line}, possibly_undefined={self.tracker.is_possibly_undefined(o.name)}, defined_in_different_branch={self.tracker.is_defined_in_different_branch(o.name)}, is_undefined={self.tracker.is_undefined(o.name)}" - ) if self.tracker.is_possibly_undefined(o.name): # A variable is only defined in some branches. self.variable_may_be_undefined(o.name, o) @@ -708,7 +653,6 @@ def visit_name_expr(self, o: NameExpr) -> None: # Case (1) will be caught by semantic analyzer. Case (2) is a forward ref that should # be caught by this visitor. Save the ref for later, so that if we see a definition, # we know it's a used-before-definition scenario. - _ambv(f"Recording undefined ref for {o.name} at line {o.line}") self.tracker.record_undefined_ref(o) super().visit_name_expr(o)