Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Optimize __(a)enter__/__(a)exit__ paths for native case #14530

Merged
merged 1 commit into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
Integer,
LoadAddress,
LoadErrorValue,
MethodCall,
RaiseStandardError,
Register,
Return,
Expand All @@ -61,6 +62,7 @@
RInstance,
exc_rtuple,
is_tagged,
none_rprimitive,
object_pointer_rprimitive,
object_rprimitive,
)
Expand Down Expand Up @@ -657,14 +659,45 @@ def transform_with(
al = "a" if is_async else ""

mgr_v = builder.accept(expr)
typ = builder.call_c(type_op, [mgr_v], line)
exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line))
value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line)
is_native = isinstance(mgr_v.type, RInstance)
if is_native:
value = builder.add(MethodCall(mgr_v, f"__{al}enter__", args=[], line=line))
exit_ = None
else:
typ = builder.call_c(type_op, [mgr_v], line)
exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line))
value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line)

mgr = builder.maybe_spill(mgr_v)
exc = builder.maybe_spill_assignable(builder.true())
if is_async:
value = emit_await(builder, value, line)

def maybe_natively_call_exit(exc_info: bool) -> Value:
if exc_info:
args = get_sys_exc_info(builder)
else:
none = builder.none_object()
args = [none, none, none]

if is_native:
assert isinstance(mgr_v.type, RInstance)
exit_val = builder.gen_method_call(
builder.read(mgr),
f"__{al}exit__",
arg_values=args,
line=line,
result_type=none_rprimitive,
)
else:
assert exit_ is not None
exit_val = builder.py_call(builder.read(exit_), [builder.read(mgr)] + args, line)

if is_async:
return emit_await(builder, exit_val, line)
else:
return exit_val

def try_body() -> None:
if target:
builder.assign(builder.get_assignment_target(target), value, line)
Expand All @@ -673,13 +706,7 @@ def try_body() -> None:
def except_body() -> None:
builder.assign(exc, builder.false(), line)
out_block, reraise_block = BasicBlock(), BasicBlock()
exit_val = builder.py_call(
builder.read(exit_), [builder.read(mgr)] + get_sys_exc_info(builder), line
)
if is_async:
exit_val = emit_await(builder, exit_val, line)

builder.add_bool_branch(exit_val, out_block, reraise_block)
builder.add_bool_branch(maybe_natively_call_exit(exc_info=True), out_block, reraise_block)
builder.activate_block(reraise_block)
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())
Expand All @@ -689,13 +716,8 @@ def finally_body() -> None:
out_block, exit_block = BasicBlock(), BasicBlock()
builder.add(Branch(builder.read(exc), exit_block, out_block, Branch.BOOL))
builder.activate_block(exit_block)
none = builder.none_object()
exit_val = builder.py_call(
builder.read(exit_), [builder.read(mgr), none, none, none], line
)
if is_async:
emit_await(builder, exit_val, line)

maybe_natively_call_exit(exc_info=False)
builder.goto_and_activate(out_block)

transform_try_finally_stmt(
Expand Down
105 changes: 105 additions & 0 deletions mypyc/test-data/irbuild-try.test
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,108 @@ L19:
L20:
return 1

[case testWithNativeSimple]
class DummyContext:
def __enter__(self) -> None:
pass
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
pass

def foo(x: DummyContext) -> None:
with x:
print('hello')
[out]
def DummyContext.__enter__(self):
self :: __main__.DummyContext
L0:
return 1
def DummyContext.__exit__(self, exc_type, exc_val, exc_tb):
self :: __main__.DummyContext
exc_type, exc_val, exc_tb :: object
L0:
return 1
def foo(x):
x :: __main__.DummyContext
r0 :: None
r1 :: bool
r2 :: str
r3 :: object
r4 :: str
r5, r6 :: object
r7, r8 :: tuple[object, object, object]
r9, r10, r11 :: object
r12 :: None
r13 :: object
r14 :: int32
r15 :: bit
r16 :: bool
r17 :: bit
r18, r19, r20 :: tuple[object, object, object]
r21 :: object
r22 :: None
r23 :: bit
L0:
r0 = x.__enter__()
r1 = 1
L1:
L2:
r2 = 'hello'
r3 = builtins :: module
r4 = 'print'
r5 = CPyObject_GetAttr(r3, r4)
r6 = PyObject_CallFunctionObjArgs(r5, r2, 0)
goto L8
L3: (handler for L2)
r7 = CPy_CatchError()
r1 = 0
r8 = CPy_GetExcInfo()
r9 = r8[0]
r10 = r8[1]
r11 = r8[2]
r12 = x.__exit__(r9, r10, r11)
r13 = box(None, r12)
r14 = PyObject_IsTrue(r13)
r15 = r14 >= 0 :: signed
r16 = truncate r14: int32 to builtins.bool
if r16 goto L5 else goto L4 :: bool
L4:
CPy_Reraise()
unreachable
L5:
L6:
CPy_RestoreExcInfo(r7)
goto L8
L7: (handler for L3, L4, L5)
CPy_RestoreExcInfo(r7)
r17 = CPy_KeepPropagating()
unreachable
L8:
L9:
L10:
r18 = <error> :: tuple[object, object, object]
r19 = r18
goto L12
L11: (handler for L1, L6, L7, L8)
r20 = CPy_CatchError()
r19 = r20
L12:
if r1 goto L13 else goto L14 :: bool
L13:
r21 = load_address _Py_NoneStruct
r22 = x.__exit__(r21, r21, r21)
L14:
if is_error(r19) goto L16 else goto L15
L15:
CPy_Reraise()
unreachable
L16:
goto L20
L17: (handler for L12, L13, L14, L15)
if is_error(r19) goto L19 else goto L18
L18:
CPy_RestoreExcInfo(r19)
L19:
r23 = CPy_KeepPropagating()
unreachable
L20:
return 1
17 changes: 17 additions & 0 deletions mypyc/test-data/run-generators.test
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,20 @@ def list_comp() -> List[int]:
[file driver.py]
from native import list_comp
assert list_comp() == [5]

[case testWithNative]
class DummyContext:
def __init__(self) -> None:
self.x = 0

def __enter__(self) -> None:
self.x += 1

def __exit__(self, exc_type, exc_value, exc_tb) -> None:
self.x -= 1

def test_basic() -> None:
context = DummyContext()
with context:
assert context.x == 1
assert context.x == 0
30 changes: 30 additions & 0 deletions mypyc/test-data/run-misc.test
Original file line number Diff line number Diff line change
Expand Up @@ -1116,3 +1116,33 @@ i = b"foo"

def test_redefinition() -> None:
assert i == b"foo"

[case testWithNative]
class DummyContext:
def __init__(self):
self.c = 0
def __enter__(self) -> None:
self.c += 1
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.c -= 1

def test_dummy_context() -> None:
c = DummyContext()
with c:
assert c.c == 1
assert c.c == 0

[case testWithNativeVarArgs]
class DummyContext:
def __init__(self):
self.c = 0
def __enter__(self) -> None:
self.c += 1
def __exit__(self, *args: object) -> None:
self.c -= 1

def test_dummy_context() -> None:
c = DummyContext()
with c:
assert c.c == 1
assert c.c == 0