diff --git a/mypyc/__main__.py b/mypyc/__main__.py index a3b9d21bc65a..653199e0fb55 100644 --- a/mypyc/__main__.py +++ b/mypyc/__main__.py @@ -24,7 +24,7 @@ from mypyc.build import mypycify setup(name='mypyc_output', - ext_modules=mypycify({}, opt_level="{}", debug_level="{}"), + ext_modules=mypycify({}, opt_level="{}", debug_level="{}", strict_dunder_typing={}), ) """ @@ -38,10 +38,11 @@ def main() -> None: opt_level = os.getenv("MYPYC_OPT_LEVEL", "3") debug_level = os.getenv("MYPYC_DEBUG_LEVEL", "1") + strict_dunder_typing = bool(int(os.getenv("MYPYC_STRICT_DUNDER_TYPING", "0"))) setup_file = os.path.join(build_dir, "setup.py") with open(setup_file, "w") as f: - f.write(setup_format.format(sys.argv[1:], opt_level, debug_level)) + f.write(setup_format.format(sys.argv[1:], opt_level, debug_level, strict_dunder_typing)) # We don't use run_setup (like we do in the test suite) because it throws # away the error code from distutils, and we don't care about the slight diff --git a/mypyc/build.py b/mypyc/build.py index 485803acba46..6d59113ef872 100644 --- a/mypyc/build.py +++ b/mypyc/build.py @@ -470,6 +470,7 @@ def mypycify( skip_cgen_input: Any | None = None, target_dir: str | None = None, include_runtime_files: bool | None = None, + strict_dunder_typing: bool = False, ) -> list[Extension]: """Main entry point to building using mypyc. @@ -509,6 +510,9 @@ def mypycify( should be directly #include'd instead of linked separately in order to reduce compiler invocations. Defaults to False in multi_file mode, True otherwise. + strict_dunder_typing: If True, force dunder methods to have the return type + of the method strictly, which can lead to more + optimization opportunities. Defaults to False. """ # Figure out our configuration @@ -519,6 +523,7 @@ def mypycify( separate=separate is not False, target_dir=target_dir, include_runtime_files=include_runtime_files, + strict_dunder_typing=strict_dunder_typing, ) # Generate all the actual important C code diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index bcc9594adcb9..6ff6308b81d8 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -81,6 +81,7 @@ pytype_from_template_op, type_object_op, ) +from mypyc.subtype import is_subtype def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: @@ -801,30 +802,42 @@ def create_ne_from_eq(builder: IRBuilder, cdef: ClassDef) -> None: def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> None: """Generate a "__ne__" method from a "__eq__" method.""" - with builder.enter_method(cls, "__ne__", object_rprimitive): - rhs_arg = builder.add_argument("rhs", object_rprimitive) - - # If __eq__ returns NotImplemented, then __ne__ should also - not_implemented_block, regular_block = BasicBlock(), BasicBlock() + func_ir = cls.get_method("__eq__") + assert func_ir + eq_sig = func_ir.decl.sig + strict_typing = builder.options.strict_dunders_typing + with builder.enter_method(cls, "__ne__", eq_sig.ret_type): + rhs_type = eq_sig.args[0].type if strict_typing else object_rprimitive + rhs_arg = builder.add_argument("rhs", rhs_type) eqval = builder.add(MethodCall(builder.self(), "__eq__", [rhs_arg], line)) - not_implemented = builder.add( - LoadAddress(not_implemented_op.type, not_implemented_op.src, line) - ) - builder.add( - Branch( - builder.translate_is_op(eqval, not_implemented, "is", line), - not_implemented_block, - regular_block, - Branch.BOOL, - ) - ) - builder.activate_block(regular_block) - retval = builder.coerce(builder.unary_op(eqval, "not", line), object_rprimitive, line) - builder.add(Return(retval)) + can_return_not_implemented = is_subtype(not_implemented_op.type, eq_sig.ret_type) + return_bool = is_subtype(eq_sig.ret_type, bool_rprimitive) - builder.activate_block(not_implemented_block) - builder.add(Return(not_implemented)) + if not strict_typing or can_return_not_implemented: + # If __eq__ returns NotImplemented, then __ne__ should also + not_implemented_block, regular_block = BasicBlock(), BasicBlock() + not_implemented = builder.add( + LoadAddress(not_implemented_op.type, not_implemented_op.src, line) + ) + builder.add( + Branch( + builder.translate_is_op(eqval, not_implemented, "is", line), + not_implemented_block, + regular_block, + Branch.BOOL, + ) + ) + builder.activate_block(regular_block) + rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive + retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line) + builder.add(Return(retval)) + builder.activate_block(not_implemented_block) + builder.add(Return(not_implemented)) + else: + rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive + retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line) + builder.add(Return(retval)) def load_non_ext_class( diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index c985e88b0e0c..f9d55db50f27 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -96,7 +96,8 @@ def transform_func_def(builder: IRBuilder, fdef: FuncDef) -> None: - func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, builder.mapper.fdef_to_sig(fdef)) + sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing) + func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, sig) # If the function that was visited was a nested function, then either look it up in our # current environment or define it if it was not already defined. @@ -113,9 +114,8 @@ def transform_overloaded_func_def(builder: IRBuilder, o: OverloadedFuncDef) -> N def transform_decorator(builder: IRBuilder, dec: Decorator) -> None: - func_ir, func_reg = gen_func_item( - builder, dec.func, dec.func.name, builder.mapper.fdef_to_sig(dec.func) - ) + sig = builder.mapper.fdef_to_sig(dec.func, builder.options.strict_dunders_typing) + func_ir, func_reg = gen_func_item(builder, dec.func, dec.func.name, sig) decorated_func: Value | None = None if func_reg: decorated_func = load_decorated_func(builder, dec.func, func_reg) @@ -416,7 +416,8 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None # Perform the function of visit_method for methods inside extension classes. name = fdef.name class_ir = builder.mapper.type_to_ir[cdef.info] - func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef) + sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing) + func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef) builder.functions.append(func_ir) if is_decorated(builder, fdef): @@ -481,7 +482,8 @@ def handle_non_ext_method( ) -> None: # Perform the function of visit_method for methods inside non-extension classes. name = fdef.name - func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef) + sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing) + func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef) assert func_reg is not None builder.functions.append(func_ir) diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index c98136ce06d2..cb23a74c69c6 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -14,7 +14,7 @@ from mypy.argmap import map_actuals_to_formals from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind -from mypy.operators import op_methods +from mypy.operators import op_methods, unary_op_methods from mypy.types import AnyType, TypeOfAny from mypyc.common import ( BITMAP_BITS, @@ -167,6 +167,7 @@ buf_init_item, fast_isinstance_op, none_object_op, + not_implemented_op, var_object_size, ) from mypyc.primitives.registry import ( @@ -1398,11 +1399,48 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: if base_op in float_op_to_id: return self.float_op(lreg, rreg, base_op, line) + dunder_op = self.dunder_op(lreg, rreg, op, line) + if dunder_op: + return dunder_op + primitive_ops_candidates = binary_ops.get(op, []) target = self.matching_primitive_op(primitive_ops_candidates, [lreg, rreg], line) assert target, "Unsupported binary operation: %s" % op return target + def dunder_op(self, lreg: Value, rreg: Value | None, op: str, line: int) -> Value | None: + """ + Dispatch a dunder method if applicable. + For example for `a + b` it will use `a.__add__(b)` which can lead to higher performance + due to the fact that the method could be already compiled and optimized instead of going + all the way through `PyNumber_Add(a, b)` python api (making a jump into the python DL). + """ + ltype = lreg.type + if not isinstance(ltype, RInstance): + return None + + method_name = op_methods.get(op) if rreg else unary_op_methods.get(op) + if method_name is None: + return None + + if not ltype.class_ir.has_method(method_name): + return None + + decl = ltype.class_ir.method_decl(method_name) + if not rreg and len(decl.sig.args) != 1: + return None + + if rreg and (len(decl.sig.args) != 2 or not is_subtype(rreg.type, decl.sig.args[1].type)): + return None + + if rreg and is_subtype(not_implemented_op.type, decl.sig.ret_type): + # If the method is able to return NotImplemented, we should not optimize it. + # We can just let go so it will be handled through the python api. + return None + + args = [rreg] if rreg else [] + return self.gen_method_call(lreg, method_name, args, decl.sig.ret_type, line) + def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value: """Check if a tagged integer is a short integer. @@ -1558,16 +1596,9 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value: if isinstance(value, Float): return Float(-value.value, value.line) if isinstance(typ, RInstance): - if expr_op == "-": - method = "__neg__" - elif expr_op == "+": - method = "__pos__" - elif expr_op == "~": - method = "__invert__" - else: - method = "" - if method and typ.class_ir.has_method(method): - return self.gen_method_call(value, method, [], None, line) + result = self.dunder_op(value, None, expr_op, line) + if result is not None: + return result call_c_ops_candidates = unary_ops.get(expr_op, []) target = self.matching_call_c(call_c_ops_candidates, [value], line) assert target, "Unsupported unary operation: %s" % expr_op diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 90ce0e16c741..78e54aceed3d 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -160,7 +160,7 @@ def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType: else: return self.type_to_rtype(typ) - def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature: + def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignature: if isinstance(fdef.type, CallableType): arg_types = [ self.get_arg_rtype(typ, kind) @@ -199,11 +199,14 @@ def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature: ) ] - # We force certain dunder methods to return objects to support letting them - # return NotImplemented. It also avoids some pointless boxing and unboxing, - # since tp_richcompare needs an object anyways. - if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"): - ret = object_rprimitive + if not strict_dunders_typing: + # We force certain dunder methods to return objects to support letting them + # return NotImplemented. It also avoids some pointless boxing and unboxing, + # since tp_richcompare needs an object anyways. + # However, it also prevents some optimizations. + if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"): + ret = object_rprimitive + return FuncSignature(args, ret) def is_native_module(self, module: str) -> bool: diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index 40a40b79df49..bed9a1684326 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -99,9 +99,11 @@ def build_type_map( for module, cdef in classes: with catch_errors(module.path, cdef.line): if mapper.type_to_ir[cdef.info].is_ext_class: - prepare_class_def(module.path, module.fullname, cdef, errors, mapper) + prepare_class_def(module.path, module.fullname, cdef, errors, mapper, options) else: - prepare_non_ext_class_def(module.path, module.fullname, cdef, errors, mapper) + prepare_non_ext_class_def( + module.path, module.fullname, cdef, errors, mapper, options + ) # Prepare implicit attribute accessors as needed if an attribute overrides a property. for module, cdef in classes: @@ -114,7 +116,7 @@ def build_type_map( # is conditionally defined. for module in modules: for func in get_module_func_defs(module): - prepare_func_def(module.fullname, None, func, mapper) + prepare_func_def(module.fullname, None, func, mapper, options) # TODO: what else? # Check for incompatible attribute definitions that were not @@ -168,27 +170,39 @@ def get_module_func_defs(module: MypyFile) -> Iterable[FuncDef]: def prepare_func_def( - module_name: str, class_name: str | None, fdef: FuncDef, mapper: Mapper + module_name: str, + class_name: str | None, + fdef: FuncDef, + mapper: Mapper, + options: CompilerOptions, ) -> FuncDecl: kind = ( FUNC_STATICMETHOD if fdef.is_static else (FUNC_CLASSMETHOD if fdef.is_class else FUNC_NORMAL) ) - decl = FuncDecl(fdef.name, class_name, module_name, mapper.fdef_to_sig(fdef), kind) + sig = mapper.fdef_to_sig(fdef, options.strict_dunders_typing) + decl = FuncDecl(fdef.name, class_name, module_name, sig, kind) mapper.func_to_decl[fdef] = decl return decl def prepare_method_def( - ir: ClassIR, module_name: str, cdef: ClassDef, mapper: Mapper, node: FuncDef | Decorator + ir: ClassIR, + module_name: str, + cdef: ClassDef, + mapper: Mapper, + node: FuncDef | Decorator, + options: CompilerOptions, ) -> None: if isinstance(node, FuncDef): - ir.method_decls[node.name] = prepare_func_def(module_name, cdef.name, node, mapper) + ir.method_decls[node.name] = prepare_func_def( + module_name, cdef.name, node, mapper, options + ) elif isinstance(node, Decorator): # TODO: do something about abstract methods here. Currently, they are handled just like # normal methods. - decl = prepare_func_def(module_name, cdef.name, node.func, mapper) + decl = prepare_func_def(module_name, cdef.name, node.func, mapper, options) if not node.decorators: ir.method_decls[node.name] = decl elif isinstance(node.decorators[0], MemberExpr) and node.decorators[0].name == "setter": @@ -241,7 +255,12 @@ def can_subclass_builtin(builtin_base: str) -> bool: def prepare_class_def( - path: str, module_name: str, cdef: ClassDef, errors: Errors, mapper: Mapper + path: str, + module_name: str, + cdef: ClassDef, + errors: Errors, + mapper: Mapper, + options: CompilerOptions, ) -> None: """Populate the interface-level information in a class IR. @@ -308,7 +327,7 @@ def prepare_class_def( ir.mro = mro ir.base_mro = base_mro - prepare_methods_and_attributes(cdef, ir, path, module_name, errors, mapper) + prepare_methods_and_attributes(cdef, ir, path, module_name, errors, mapper, options) prepare_init_method(cdef, ir, module_name, mapper) for base in bases: @@ -320,7 +339,13 @@ def prepare_class_def( def prepare_methods_and_attributes( - cdef: ClassDef, ir: ClassIR, path: str, module_name: str, errors: Errors, mapper: Mapper + cdef: ClassDef, + ir: ClassIR, + path: str, + module_name: str, + errors: Errors, + mapper: Mapper, + options: CompilerOptions, ) -> None: """Populate attribute and method declarations.""" info = cdef.info @@ -342,20 +367,20 @@ def prepare_methods_and_attributes( add_setter_declaration(ir, name, attr_rtype, module_name) ir.attributes[name] = attr_rtype elif isinstance(node.node, (FuncDef, Decorator)): - prepare_method_def(ir, module_name, cdef, mapper, node.node) + prepare_method_def(ir, module_name, cdef, mapper, node.node, options) elif isinstance(node.node, OverloadedFuncDef): # Handle case for property with both a getter and a setter if node.node.is_property: if is_valid_multipart_property_def(node.node): for item in node.node.items: - prepare_method_def(ir, module_name, cdef, mapper, item) + prepare_method_def(ir, module_name, cdef, mapper, item, options) else: errors.error("Unsupported property decorator semantics", path, cdef.line) # Handle case for regular function overload else: assert node.node.impl - prepare_method_def(ir, module_name, cdef, mapper, node.node.impl) + prepare_method_def(ir, module_name, cdef, mapper, node.node.impl, options) if ir.builtin_base: ir.attributes.clear() @@ -440,7 +465,7 @@ def prepare_init_method(cdef: ClassDef, ir: ClassIR, module_name: str, mapper: M # Set up a constructor decl init_node = cdef.info["__init__"].node if not ir.is_trait and not ir.builtin_base and isinstance(init_node, FuncDef): - init_sig = mapper.fdef_to_sig(init_node) + init_sig = mapper.fdef_to_sig(init_node, True) defining_ir = mapper.type_to_ir.get(init_node.info) # If there is a nontrivial __init__ that wasn't defined in an @@ -467,24 +492,29 @@ def prepare_init_method(cdef: ClassDef, ir: ClassIR, module_name: str, mapper: M def prepare_non_ext_class_def( - path: str, module_name: str, cdef: ClassDef, errors: Errors, mapper: Mapper + path: str, + module_name: str, + cdef: ClassDef, + errors: Errors, + mapper: Mapper, + options: CompilerOptions, ) -> None: ir = mapper.type_to_ir[cdef.info] info = cdef.info for node in info.names.values(): if isinstance(node.node, (FuncDef, Decorator)): - prepare_method_def(ir, module_name, cdef, mapper, node.node) + prepare_method_def(ir, module_name, cdef, mapper, node.node, options) elif isinstance(node.node, OverloadedFuncDef): # Handle case for property with both a getter and a setter if node.node.is_property: if not is_valid_multipart_property_def(node.node): errors.error("Unsupported property decorator semantics", path, cdef.line) for item in node.node.items: - prepare_method_def(ir, module_name, cdef, mapper, item) + prepare_method_def(ir, module_name, cdef, mapper, item, options) # Handle case for regular function overload else: - prepare_method_def(ir, module_name, cdef, mapper, get_func_def(node.node)) + prepare_method_def(ir, module_name, cdef, mapper, get_func_def(node.node), options) if any(cls in mapper.type_to_ir and mapper.type_to_ir[cls].is_ext_class for cls in info.mro): errors.error( diff --git a/mypyc/options.py b/mypyc/options.py index 5f0cf12aeefe..24e68163bb11 100644 --- a/mypyc/options.py +++ b/mypyc/options.py @@ -14,6 +14,7 @@ def __init__( include_runtime_files: bool | None = None, capi_version: tuple[int, int] | None = None, python_version: tuple[int, int] | None = None, + strict_dunder_typing: bool = False, ) -> None: self.strip_asserts = strip_asserts self.multi_file = multi_file @@ -30,3 +31,10 @@ def __init__( # features are used. self.capi_version = capi_version or sys.version_info[:2] self.python_version = python_version + # Make possible to inline dunder methods in the generated code. + # Typically, the convention is the dunder methods can return `NotImplemented` + # even when its return type is just `bool`. + # By enabling this option, this convention is no longer valid and the dunder + # will assume the return type of the method strictly, which can lead to + # more optimization opportunities. + self.strict_dunders_typing = strict_dunder_typing diff --git a/mypyc/test-data/run-dunders-special.test b/mypyc/test-data/run-dunders-special.test new file mode 100644 index 000000000000..cd02cca65eef --- /dev/null +++ b/mypyc/test-data/run-dunders-special.test @@ -0,0 +1,9 @@ +[case testDundersNotImplemented] +# This case is special because it tests the behavior of NotImplemented +# used in a typed function which return type is bool. +# This is a convention that can be overriden by the user. +class UsesNotImplemented: + def __eq__(self, b: object) -> bool: + return NotImplemented + +assert UsesNotImplemented() != object() diff --git a/mypyc/test-data/run-dunders.test b/mypyc/test-data/run-dunders.test index 2845187de2c3..b8fb13c9dcec 100644 --- a/mypyc/test-data/run-dunders.test +++ b/mypyc/test-data/run-dunders.test @@ -32,10 +32,6 @@ class BoxedThing: class Subclass2(BoxedThing): pass -class UsesNotImplemented: - def __eq__(self, b: object) -> bool: - return NotImplemented - def index_into(x : Any, y : Any) -> Any: return x[y] @@ -81,8 +77,6 @@ assert is_truthy(Item('a')) assert not is_truthy(Subclass1('')) assert is_truthy(Subclass1('a')) -assert UsesNotImplemented() != object() - internal_index_into() [out] 7 7 @@ -943,3 +937,31 @@ def test_errors() -> None: pow(ForwardNotImplemented(), Child(), 3) # type: ignore with assertRaises(TypeError, "unsupported operand type(s) for ** or pow(): 'ForwardModRequired' and 'int'"): ForwardModRequired()**3 # type: ignore + +[case testDundersWithFinal] +from typing import final +class A: + def __init__(self, x: int) -> None: + self.x = x + + def __add__(self, y: int) -> int: + return self.x + y + + def __lt__(self, x: 'A') -> bool: + return self.x < x.x + +@final +class B(A): + def __add__(self, y: int) -> int: + return self.x + y + 1 + + def __lt__(self, x: 'A') -> bool: + return self.x < x.x + 1 + +def test_final() -> None: + a = A(5) + b = B(5) + assert a + 3 == 8 + assert b + 3 == 9 + assert (a < A(5)) is False + assert (b < A(5)) is True diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 668e5b124841..dd3c79da7b9b 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -63,6 +63,7 @@ "run-bench.test", "run-mypy-sim.test", "run-dunders.test", + "run-dunders-special.test", "run-singledispatch.test", "run-attrs.test", "run-python37.test", @@ -140,6 +141,7 @@ class TestRun(MypycDataSuite): optional_out = True multi_file = False separate = False # If True, using separate (incremental) compilation + strict_dunder_typing = False def run_case(self, testcase: DataDrivenTestCase) -> None: # setup.py wants to be run from the root directory of the package, which we accommodate @@ -232,7 +234,11 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) -> groups = construct_groups(sources, separate, len(module_names) > 1) try: - compiler_options = CompilerOptions(multi_file=self.multi_file, separate=self.separate) + compiler_options = CompilerOptions( + multi_file=self.multi_file, + separate=self.separate, + strict_dunder_typing=self.strict_dunder_typing, + ) result = emitmodule.parse_and_typecheck( sources=sources, options=options, @@ -401,6 +407,14 @@ class TestRunSeparate(TestRun): files = ["run-multimodule.test", "run-mypy-sim.test"] +class TestRunStrictDunderTyping(TestRun): + """Run the tests with strict dunder typing.""" + + strict_dunder_typing = True + test_name_suffix = "_dunder_typing" + files = ["run-dunders.test", "run-floats.test"] + + def fix_native_line_number(message: str, fnam: str, delta: int) -> str: """Update code locations in test case output to point to the .test file.