Skip to content

Commit

Permalink
[used before def] improve handling of global definitions in local sco…
Browse files Browse the repository at this point in the history
…pes (#14517)

While working on #14483, we discovered that variable inheritance didn't
work quite right. In particular, functions would inherit variables from
outer scope. On the surface, this is what you want but actually, they
only inherit the scope if there isn't a colliding definition within that
scope.

Here's an example:
```python
class c: pass

def f0() -> None:
    s = c()  # UnboundLocalError is raised when this code is executed.
    class c: pass

def f1() -> None:
    s = c()  # No error.
```
This PR also fixes issues with builtins (exactly the same example as
above but instead of `c` we have a builtin).

Fixes #14213 (as much as is reasonable to do)
  • Loading branch information
ilinum authored Mar 1, 2023
1 parent c245e91 commit a618110
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 78 deletions.
41 changes: 27 additions & 14 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def copy(self) -> BranchState:


class BranchStatement:
def __init__(self, initial_state: BranchState) -> None:
def __init__(self, initial_state: BranchState | None = None) -> None:
if initial_state is None:
initial_state = BranchState()
self.initial_state = initial_state
self.branches: list[BranchState] = [
BranchState(
Expand Down Expand Up @@ -171,7 +173,7 @@ class ScopeType(Enum):
Global = 1
Class = 2
Func = 3
Generator = 3
Generator = 4


class Scope:
Expand Down Expand Up @@ -199,7 +201,7 @@ class DefinedVariableTracker:

def __init__(self) -> None:
# There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
self.scopes: list[Scope] = [Scope([BranchStatement(BranchState())], ScopeType.Global)]
self.scopes: list[Scope] = [Scope([BranchStatement()], ScopeType.Global)]
# disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful
# in things like try/except/finally statements.
self.disable_branch_skip = False
Expand All @@ -216,9 +218,11 @@ def _scope(self) -> Scope:

def enter_scope(self, scope_type: ScopeType) -> None:
assert len(self._scope().branch_stmts) > 0
self.scopes.append(
Scope([BranchStatement(self._scope().branch_stmts[-1].branches[-1])], scope_type)
)
initial_state = None
if scope_type == ScopeType.Generator:
# Generators are special because they inherit the outer scope.
initial_state = self._scope().branch_stmts[-1].branches[-1]
self.scopes.append(Scope([BranchStatement(initial_state)], scope_type))

def exit_scope(self) -> None:
self.scopes.pop()
Expand Down Expand Up @@ -342,13 +346,15 @@ def variable_may_be_undefined(self, name: str, context: Context) -> None:
def process_definition(self, name: str) -> None:
# Was this name previously used? If yes, it's a used-before-definition error.
if not self.tracker.in_scope(ScopeType.Class):
# Errors in class scopes are caught by the semantic analyzer.
refs = self.tracker.pop_undefined_ref(name)
for ref in refs:
if self.loops:
self.variable_may_be_undefined(name, ref)
else:
self.var_used_before_def(name, ref)
else:
# Errors in class scopes are caught by the semantic analyzer.
pass
self.tracker.record_definition(name)

def visit_global_decl(self, o: GlobalDecl) -> None:
Expand Down Expand Up @@ -415,17 +421,24 @@ def visit_match_stmt(self, o: MatchStmt) -> None:

def visit_func_def(self, o: FuncDef) -> None:
self.process_definition(o.name)
self.tracker.enter_scope(ScopeType.Func)
super().visit_func_def(o)
self.tracker.exit_scope()

def visit_func(self, o: FuncItem) -> None:
if o.is_dynamic() and not self.options.check_untyped_defs:
return
if o.arguments is not None:
for arg in o.arguments:
self.tracker.record_definition(arg.variable.name)
super().visit_func(o)

args = o.arguments or []
# Process initializers (defaults) outside the function scope.
for arg in args:
if arg.initializer is not None:
arg.initializer.accept(self)

self.tracker.enter_scope(ScopeType.Func)
for arg in args:
self.process_definition(arg.variable.name)
super().visit_var(arg.variable)
o.body.accept(self)
self.tracker.exit_scope()

def visit_generator_expr(self, o: GeneratorExpr) -> None:
self.tracker.enter_scope(ScopeType.Generator)
Expand Down Expand Up @@ -603,7 +616,7 @@ def visit_starred_pattern(self, o: StarredPattern) -> None:
super().visit_starred_pattern(o)

def visit_name_expr(self, o: NameExpr) -> None:
if o.name in self.builtins:
if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global):
return
if self.tracker.is_possibly_undefined(o.name):
# A variable is only defined in some branches.
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/run-sets.test
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def test_in_set() -> None:
assert main_set(item), f"{item!r} should be in set_main"
assert not main_negated_set(item), item

assert non_final_name_set(non_const)
global non_const
assert non_final_name_set(non_const)
non_const = "updated"
assert non_final_name_set("updated")

Expand Down
39 changes: 19 additions & 20 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -491,62 +491,61 @@ if int():

[case testDefaultArgumentExpressions]
import typing
class B: pass
class A: pass

def f(x: 'A' = A()) -> None:
b = x # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B")
a = x # type: A

class B: pass
class A: pass
[out]

[case testDefaultArgumentExpressions2]
import typing
def f(x: 'A' = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "A")
b = x # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B")
a = x # type: A

class B: pass
class A: pass

def f(x: 'A' = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "A")
b = x # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B")
a = x # type: A
[case testDefaultArgumentExpressionsGeneric]
from typing import TypeVar
T = TypeVar('T', bound='A')
def f(x: T = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "T")
b = x # type: B # E: Incompatible types in assignment (expression has type "T", variable has type "B")
a = x # type: A

class B: pass
class A: pass

def f(x: T = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "T")
b = x # type: B # E: Incompatible types in assignment (expression has type "T", variable has type "B")
a = x # type: A
[case testDefaultArgumentsWithSubtypes]
import typing
class A: pass
class B(A): pass

def f(x: 'B' = A()) -> None: # E: Incompatible default for argument "x" (default has type "A", argument has type "B")
pass
def g(x: 'A' = B()) -> None:
pass

class A: pass
class B(A): pass
[out]

[case testMultipleDefaultArgumentExpressions]
import typing
class A: pass
class B: pass

def f(x: 'A' = B(), y: 'B' = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "A")
pass
def h(x: 'A' = A(), y: 'B' = B()) -> None:
pass

class A: pass
class B: pass
[out]

[case testMultipleDefaultArgumentExpressions2]
import typing
def g(x: 'A' = A(), y: 'B' = A()) -> None: # E: Incompatible default for argument "y" (default has type "A", argument has type "B")
pass

class A: pass
class B: pass

def g(x: 'A' = A(), y: 'B' = A()) -> None: # E: Incompatible default for argument "y" (default has type "A", argument has type "B")
pass
[out]

[case testDefaultArgumentsAndSignatureAsComment]
Expand Down Expand Up @@ -2612,7 +2611,7 @@ def f() -> int: ...
[case testLambdaDefaultTypeErrors]
lambda a=(1 + 'asdf'): a # E: Unsupported operand types for + ("int" and "str")
lambda a=nonsense: a # E: Name "nonsense" is not defined
def f(x: int = i): # E: Name "i" is not defined # E: Name "i" is used before definition
def f(x: int = i): # E: Name "i" is not defined
i = 42

[case testRevealTypeOfCallExpressionReturningNoneWorks]
Expand Down
Loading

0 comments on commit a618110

Please sign in to comment.