Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Optimize dunder methods #17934

Open
wants to merge 36 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
354d492
Optimize calls to final classes
jairov4 Oct 6, 2024
63efe61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2024
e844530
Assert rtype
jairov4 Oct 7, 2024
cdb77ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2024
c5d8192
Restore the placement vtable to preserve the PyObject layout
jairov4 Oct 7, 2024
489259e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2024
a35beb6
Restore the placement vtable to preserve the PyObject layout
jairov4 Oct 7, 2024
4ebf8b9
revert to initialize vtable always
jairov4 Oct 7, 2024
581a911
Optimize calls to binary ops with dunders
jairov4 Oct 7, 2024
0caf3c4
Optimize calls to binary ops with dunders
jairov4 Oct 7, 2024
d7f3fc8
Update test for dunders IR
jairov4 Oct 7, 2024
c151361
Fix arguments for dunders
jairov4 Oct 7, 2024
bf7d2b0
Fix test for IR dunder call
jairov4 Oct 7, 2024
572e835
Fix the case for dunders that returns NotImplemented
jairov4 Oct 7, 2024
48ba2ad
Fix the case for dunders that returns NotImplemented
jairov4 Oct 7, 2024
8b5eaf4
Allow dunders to return custom types in IR
jairov4 Oct 7, 2024
cc85adc
Improve the analysis to determine if a method returns NotImplemented
jairov4 Oct 7, 2024
edca454
Make the dunder typing strictness configurable
jairov4 Oct 8, 2024
dd6993a
Make easier use the optimization for dunders
jairov4 Oct 8, 2024
5873b3d
gen_glue_ne_method honors strict_dunders_typing
jairov4 Oct 8, 2024
a664e0a
Fix the type of the dunder argument in gen_glue_ne_method
jairov4 Oct 8, 2024
57306aa
Dup 3 lines to avoid change the tests
jairov4 Oct 8, 2024
4753a47
Revert change on irbuild-classes.test
jairov4 Oct 8, 2024
9d6e01c
Add tests for strict dunder typing mode
jairov4 Oct 8, 2024
38b3142
Fix return type for some methods
jairov4 Oct 8, 2024
71a5d34
Pass compiler options
jairov4 Oct 8, 2024
562747e
Pass compiler options
jairov4 Oct 8, 2024
8adbcbf
Introduce run-dunders-special.test
jairov4 Oct 8, 2024
04e0d98
Add test case with final decorator
jairov4 Oct 8, 2024
eb16941
Add test case with final decorator
jairov4 Oct 8, 2024
c0796a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2024
ec01ab3
Fix a mypy warning
jairov4 Oct 14, 2024
2908482
Fix a mypy warning
jairov4 Oct 14, 2024
7b17521
Fix a mypy warning
jairov4 Oct 14, 2024
34eb9c1
Restore finalize method
jairov4 Oct 14, 2024
2379a28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mypyc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mypyc.build import mypycify

setup(name='mypyc_output',
ext_modules=mypycify({}, opt_level="{}", debug_level="{}"),
ext_modules=mypycify({}, opt_level="{}", debug_level="{}", strict_dunder_typing={}),
)
"""

Expand All @@ -38,10 +38,11 @@ def main() -> None:

opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
debug_level = os.getenv("MYPYC_DEBUG_LEVEL", "1")
strict_dunder_typing = bool(int(os.getenv("MYPYC_STRICT_DUNDER_TYPING", "0")))

setup_file = os.path.join(build_dir, "setup.py")
with open(setup_file, "w") as f:
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level))
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level, strict_dunder_typing))

# We don't use run_setup (like we do in the test suite) because it throws
# away the error code from distutils, and we don't care about the slight
Expand Down
5 changes: 5 additions & 0 deletions mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def mypycify(
skip_cgen_input: Any | None = None,
target_dir: str | None = None,
include_runtime_files: bool | None = None,
strict_dunder_typing: bool = False,
) -> list[Extension]:
"""Main entry point to building using mypyc.

Expand Down Expand Up @@ -509,6 +510,9 @@ def mypycify(
should be directly #include'd instead of linked
separately in order to reduce compiler invocations.
Defaults to False in multi_file mode, True otherwise.
strict_dunder_typing: If True, force dunder methods to have the return type
of the method strictly, which can lead to more
optimization opportunities. Defaults to False.
"""

# Figure out our configuration
Expand All @@ -519,6 +523,7 @@ def mypycify(
separate=separate is not False,
target_dir=target_dir,
include_runtime_files=include_runtime_files,
strict_dunder_typing=strict_dunder_typing,
)

# Generate all the actual important C code
Expand Down
55 changes: 34 additions & 21 deletions mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
pytype_from_template_op,
type_object_op,
)
from mypyc.subtype import is_subtype


def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
Expand Down Expand Up @@ -801,30 +802,42 @@ def create_ne_from_eq(builder: IRBuilder, cdef: ClassDef) -> None:

def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> None:
"""Generate a "__ne__" method from a "__eq__" method."""
with builder.enter_method(cls, "__ne__", object_rprimitive):
rhs_arg = builder.add_argument("rhs", object_rprimitive)

# If __eq__ returns NotImplemented, then __ne__ should also
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
func_ir = cls.get_method("__eq__")
assert func_ir
eq_sig = func_ir.decl.sig
strict_typing = builder.options.strict_dunders_typing
with builder.enter_method(cls, "__ne__", eq_sig.ret_type):
rhs_type = eq_sig.args[0].type if strict_typing else object_rprimitive
rhs_arg = builder.add_argument("rhs", rhs_type)
eqval = builder.add(MethodCall(builder.self(), "__eq__", [rhs_arg], line))
not_implemented = builder.add(
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
)
builder.add(
Branch(
builder.translate_is_op(eqval, not_implemented, "is", line),
not_implemented_block,
regular_block,
Branch.BOOL,
)
)

builder.activate_block(regular_block)
retval = builder.coerce(builder.unary_op(eqval, "not", line), object_rprimitive, line)
builder.add(Return(retval))
can_return_not_implemented = is_subtype(not_implemented_op.type, eq_sig.ret_type)
return_bool = is_subtype(eq_sig.ret_type, bool_rprimitive)

builder.activate_block(not_implemented_block)
builder.add(Return(not_implemented))
if not strict_typing or can_return_not_implemented:
# If __eq__ returns NotImplemented, then __ne__ should also
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
not_implemented = builder.add(
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
)
builder.add(
Branch(
builder.translate_is_op(eqval, not_implemented, "is", line),
not_implemented_block,
regular_block,
Branch.BOOL,
)
)
builder.activate_block(regular_block)
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
builder.add(Return(retval))
builder.activate_block(not_implemented_block)
builder.add(Return(not_implemented))
else:
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
builder.add(Return(retval))


def load_non_ext_class(
Expand Down
14 changes: 8 additions & 6 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@


def transform_func_def(builder: IRBuilder, fdef: FuncDef) -> None:
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, builder.mapper.fdef_to_sig(fdef))
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, sig)

# If the function that was visited was a nested function, then either look it up in our
# current environment or define it if it was not already defined.
Expand All @@ -113,9 +114,8 @@ def transform_overloaded_func_def(builder: IRBuilder, o: OverloadedFuncDef) -> N


def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
func_ir, func_reg = gen_func_item(
builder, dec.func, dec.func.name, builder.mapper.fdef_to_sig(dec.func)
)
sig = builder.mapper.fdef_to_sig(dec.func, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, dec.func, dec.func.name, sig)
decorated_func: Value | None = None
if func_reg:
decorated_func = load_decorated_func(builder, dec.func, func_reg)
Expand Down Expand Up @@ -416,7 +416,8 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None
# Perform the function of visit_method for methods inside extension classes.
name = fdef.name
class_ir = builder.mapper.type_to_ir[cdef.info]
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
builder.functions.append(func_ir)

if is_decorated(builder, fdef):
Expand Down Expand Up @@ -481,7 +482,8 @@ def handle_non_ext_method(
) -> None:
# Perform the function of visit_method for methods inside non-extension classes.
name = fdef.name
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
assert func_reg is not None
builder.functions.append(func_ir)

Expand Down
53 changes: 42 additions & 11 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
from mypy.operators import op_methods
from mypy.operators import op_methods, unary_op_methods
from mypy.types import AnyType, TypeOfAny
from mypyc.common import (
BITMAP_BITS,
Expand Down Expand Up @@ -167,6 +167,7 @@
buf_init_item,
fast_isinstance_op,
none_object_op,
not_implemented_op,
var_object_size,
)
from mypyc.primitives.registry import (
Expand Down Expand Up @@ -1398,11 +1399,48 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
if base_op in float_op_to_id:
return self.float_op(lreg, rreg, base_op, line)

dunder_op = self.dunder_op(lreg, rreg, op, line)
if dunder_op:
return dunder_op

primitive_ops_candidates = binary_ops.get(op, [])
target = self.matching_primitive_op(primitive_ops_candidates, [lreg, rreg], line)
assert target, "Unsupported binary operation: %s" % op
return target

def dunder_op(self, lreg: Value, rreg: Value | None, op: str, line: int) -> Value | None:
"""
Dispatch a dunder method if applicable.
For example for `a + b` it will use `a.__add__(b)` which can lead to higher performance
due to the fact that the method could be already compiled and optimized instead of going
all the way through `PyNumber_Add(a, b)` python api (making a jump into the python DL).
"""
ltype = lreg.type
if not isinstance(ltype, RInstance):
return None

method_name = op_methods.get(op) if rreg else unary_op_methods.get(op)
if method_name is None:
return None

if not ltype.class_ir.has_method(method_name):
return None

decl = ltype.class_ir.method_decl(method_name)
if not rreg and len(decl.sig.args) != 1:
return None

if rreg and (len(decl.sig.args) != 2 or not is_subtype(rreg.type, decl.sig.args[1].type)):
return None

if rreg and is_subtype(not_implemented_op.type, decl.sig.ret_type):
# If the method is able to return NotImplemented, we should not optimize it.
# We can just let go so it will be handled through the python api.
return None

args = [rreg] if rreg else []
return self.gen_method_call(lreg, method_name, args, decl.sig.ret_type, line)

def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value:
"""Check if a tagged integer is a short integer.

Expand Down Expand Up @@ -1558,16 +1596,9 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
if isinstance(value, Float):
return Float(-value.value, value.line)
if isinstance(typ, RInstance):
if expr_op == "-":
method = "__neg__"
elif expr_op == "+":
method = "__pos__"
elif expr_op == "~":
method = "__invert__"
else:
method = ""
if method and typ.class_ir.has_method(method):
return self.gen_method_call(value, method, [], None, line)
result = self.dunder_op(value, None, expr_op, line)
if result is not None:
return result
call_c_ops_candidates = unary_ops.get(expr_op, [])
target = self.matching_call_c(call_c_ops_candidates, [value], line)
assert target, "Unsupported unary operation: %s" % expr_op
Expand Down
15 changes: 9 additions & 6 deletions mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType:
else:
return self.type_to_rtype(typ)

def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignature:
if isinstance(fdef.type, CallableType):
arg_types = [
self.get_arg_rtype(typ, kind)
Expand Down Expand Up @@ -199,11 +199,14 @@ def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
)
]

# We force certain dunder methods to return objects to support letting them
# return NotImplemented. It also avoids some pointless boxing and unboxing,
# since tp_richcompare needs an object anyways.
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
ret = object_rprimitive
if not strict_dunders_typing:
# We force certain dunder methods to return objects to support letting them
# return NotImplemented. It also avoids some pointless boxing and unboxing,
# since tp_richcompare needs an object anyways.
# However, it also prevents some optimizations.
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
ret = object_rprimitive

return FuncSignature(args, ret)

def is_native_module(self, module: str) -> bool:
Expand Down
Loading
Loading