Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Precompute set literals for "in" ops against / iteration over set literals #14409

Merged
merged 17 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,15 @@ def check_tuple_items_valid_literals(self, op: LoadLiteral, t: tuple[object, ...
if isinstance(x, tuple):
self.check_tuple_items_valid_literals(op, x)

def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[object]) -> None:
for x in s:
if x is None or isinstance(x, (str, bytes, bool, int, float, complex)):
pass
elif isinstance(x, tuple):
self.check_tuple_items_valid_literals(op, x)
else:
self.fail(op, f"Invalid type for item of frozenset literal: {type(x)})")

def visit_load_literal(self, op: LoadLiteral) -> None:
expected_type = None
if op.value is None:
Expand All @@ -271,6 +280,11 @@ def visit_load_literal(self, op: LoadLiteral) -> None:
elif isinstance(op.value, tuple):
expected_type = "builtins.tuple"
self.check_tuple_items_valid_literals(op, op.value)
elif isinstance(op.value, frozenset):
# There's no frozenset_rprimitive type since it'd be pretty useless so we just pretend
# it's a set (when it's really a frozenset).
expected_type = "builtins.set"
self.check_frozenset_items_valid_literals(op, op.value)

assert expected_type is not None, "Missed a case for LoadLiteral check"

Expand Down
5 changes: 4 additions & 1 deletion mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,9 @@ def generate_literal_tables(self) -> None:
# Descriptions of tuple literals
init_tuple = c_array_initializer(literals.encoded_tuple_values())
self.declare_global("const int []", "CPyLit_Tuple", initializer=init_tuple)
# Descriptions of frozenset literals
init_frozenset = c_array_initializer(literals.encoded_frozenset_values())
self.declare_global("const int []", "CPyLit_FrozenSet", initializer=init_frozenset)

def generate_export_table(self, decl_emitter: Emitter, code_emitter: Emitter) -> None:
"""Generate the declaration and definition of the group's export struct.
Expand Down Expand Up @@ -839,7 +842,7 @@ def generate_globals_init(self, emitter: Emitter) -> None:
for symbol, fixup in self.simple_inits:
emitter.emit_line(f"{symbol} = {fixup};")

values = "CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple"
values = "CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple, CPyLit_FrozenSet"
emitter.emit_lines(
f"if (CPyStatics_Initialize(CPyStatics, {values}) < 0) {{", "return -1;", "}"
)
Expand Down
46 changes: 33 additions & 13 deletions mypyc/codegen/literals.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from typing import Any, Tuple, Union, cast
from typing import Any, Dict, FrozenSet, List, Tuple, Union, cast
from typing_extensions import Final

# Supported Python literal types. All tuple items must have supported
# Supported Python literal types. All tuple / frozenset items must have supported
# literal types as well, but we can't represent the type precisely.
LiteralValue = Union[str, bytes, int, bool, float, complex, Tuple[object, ...], None]

LiteralValue = Union[
str, bytes, int, bool, float, complex, Tuple[object, ...], FrozenSet[object], None
]

# Some literals are singletons and handled specially (None, False and True)
NUM_SINGLETONS: Final = 3
Expand All @@ -23,6 +24,7 @@ def __init__(self) -> None:
self.float_literals: dict[float, int] = {}
self.complex_literals: dict[complex, int] = {}
self.tuple_literals: dict[tuple[object, ...], int] = {}
self.frozenset_literals: dict[frozenset[object], int] = {}

def record_literal(self, value: LiteralValue) -> None:
"""Ensure that the literal value is available in generated code."""
Expand Down Expand Up @@ -55,6 +57,12 @@ def record_literal(self, value: LiteralValue) -> None:
for item in value:
self.record_literal(cast(Any, item))
tuple_literals[value] = len(tuple_literals)
elif isinstance(value, frozenset):
frozenset_literals = self.frozenset_literals
if value not in frozenset_literals:
for item in value:
self.record_literal(cast(Any, item))
frozenset_literals[value] = len(frozenset_literals)
else:
assert False, "invalid literal: %r" % value

Expand Down Expand Up @@ -86,6 +94,9 @@ def literal_index(self, value: LiteralValue) -> int:
n += len(self.complex_literals)
if isinstance(value, tuple):
return n + self.tuple_literals[value]
n += len(self.tuple_literals)
if isinstance(value, frozenset):
return n + self.frozenset_literals[value]
assert False, "invalid literal: %r" % value

def num_literals(self) -> int:
Expand All @@ -98,6 +109,7 @@ def num_literals(self) -> int:
+ len(self.float_literals)
+ len(self.complex_literals)
+ len(self.tuple_literals)
+ len(self.frozenset_literals)
)

# The following methods return the C encodings of literal values
Expand All @@ -119,24 +131,32 @@ def encoded_complex_values(self) -> list[str]:
return _encode_complex_values(self.complex_literals)

def encoded_tuple_values(self) -> list[str]:
"""Encode tuple values into a C array.
return self._encode_collection_values(self.tuple_literals)

def encoded_frozenset_values(self) -> List[str]:
return self._encode_collection_values(self.frozenset_literals)

def _encode_collection_values(
self, values: dict[tuple[object, ...], int] | dict[frozenset[object], int]
) -> list[str]:
"""Encode tuple/frozenset values into a C array.

The format of the result is like this:

<number of tuples>
<length of the first tuple>
<number of collections>
<length of the first collection>
<literal index of first item>
...
<literal index of last item>
<length of the second tuple>
<length of the second collection>
...
"""
values = self.tuple_literals
value_by_index = {index: value for value, index in values.items()}
# FIXME: https://github.com/mypyc/mypyc/issues/965
value_by_index = {index: value for value, index in cast(Dict[Any, int], values).items()}
result = []
num = len(values)
result.append(str(num))
for i in range(num):
count = len(values)
result.append(str(count))
for i in range(count):
value = value_by_index[i]
result.append(str(len(value)))
for item in value:
Expand Down
9 changes: 3 additions & 6 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)

if TYPE_CHECKING:
from mypyc.codegen.literals import LiteralValue
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FuncDecl, FuncIR

Expand Down Expand Up @@ -588,7 +589,7 @@ class LoadLiteral(RegisterOp):
This is used to load a static PyObject * value corresponding to
a literal of one of the supported types.

Tuple literals must contain only valid literal values as items.
Tuple / frozenset literals must contain only valid literal values as items.

NOTE: You can use this to load boxed (Python) int objects. Use
Integer to load unboxed, tagged integers or fixed-width,
Expand All @@ -603,11 +604,7 @@ class LoadLiteral(RegisterOp):
error_kind = ERR_NEVER
is_borrowed = True

def __init__(
self,
value: None | str | bytes | bool | int | float | complex | tuple[object, ...],
rtype: RType,
) -> None:
def __init__(self, value: LiteralValue, rtype: RType) -> None:
self.value = value
self.type = rtype

Expand Down
13 changes: 12 additions & 1 deletion mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,18 @@ def visit_load_literal(self, op: LoadLiteral) -> str:
# it explicit that this is a Python object.
if isinstance(op.value, int):
prefix = "object "
return self.format("%r = %s%s", op, prefix, repr(op.value))

rvalue = repr(op.value)
if isinstance(op.value, frozenset):
# We need to generate a string representation that won't vary
# run-to-run because sets are unordered, otherwise we may get
# spurious irbuild test failures.
#
# Sorting by the item's string representation is a bit of a
# hack, but it's stable and won't cause TypeErrors.
formatted_items = [repr(i) for i in sorted(op.value, key=str)]
rvalue = "frozenset({" + ", ".join(formatted_items) + "})"
return self.format("%r = %s%s", op, prefix, rvalue)

def visit_get_attr(self, op: GetAttr) -> str:
return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr)
Expand Down
7 changes: 3 additions & 4 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@
AssignmentTargetRegister,
AssignmentTargetTuple,
)
from mypyc.irbuild.util import is_constant
from mypyc.irbuild.util import bytes_from_str, is_constant
from mypyc.options import CompilerOptions
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op
Expand Down Expand Up @@ -296,8 +296,7 @@ def load_bytes_from_str_literal(self, value: str) -> Value:
are stored in BytesExpr.value, whose type is 'str' not 'bytes'.
Thus we perform a special conversion here.
"""
bytes_value = bytes(value, "utf8").decode("unicode-escape").encode("raw-unicode-escape")
return self.builder.load_bytes(bytes_value)
return self.builder.load_bytes(bytes_from_str(value))

def load_int(self, value: int) -> Value:
return self.builder.load_int(value)
Expand Down Expand Up @@ -886,7 +885,7 @@ def get_dict_base_type(self, expr: Expression) -> Instance:
This is useful for dict subclasses like SymbolTable.
"""
target_type = get_proper_type(self.types[expr])
assert isinstance(target_type, Instance)
assert isinstance(target_type, Instance), target_type
dict_base = next(base for base in target_type.type.mro if base.fullname == "builtins.dict")
return map_instance_to_supertype(target_type, dict_base)

Expand Down
70 changes: 68 additions & 2 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Callable, cast
from typing import Callable, Sequence, cast

from mypy.nodes import (
ARG_POS,
Expand Down Expand Up @@ -55,6 +55,7 @@
ComparisonOp,
Integer,
LoadAddress,
LoadLiteral,
RaiseStandardError,
Register,
TupleGet,
Expand All @@ -63,12 +64,14 @@
)
from mypyc.ir.rtypes import (
RTuple,
bool_rprimitive,
int_rprimitive,
is_fixed_width_rtype,
is_int_rprimitive,
is_list_rprimitive,
is_none_rprimitive,
object_rprimitive,
set_rprimitive,
)
from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional
from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op
Expand All @@ -86,14 +89,15 @@
tokenizer_printf_style,
)
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
from mypyc.irbuild.util import bytes_from_str
from mypyc.primitives.bytes_ops import bytes_slice_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op
from mypyc.primitives.generic_ops import iter_op
from mypyc.primitives.int_ops import int_comparison_op_mapping
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
from mypyc.primitives.registry import CFunctionDescription, builtin_names
from mypyc.primitives.set_ops import set_add_op, set_update_op
from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op
from mypyc.primitives.str_ops import str_slice_op
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op

Expand Down Expand Up @@ -613,6 +617,51 @@ def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Val
return target


def set_literal_values(builder: IRBuilder, items: Sequence[Expression]) -> list[object] | None:
values: list[object] = []
for item in items:
const_value = constant_fold_expr(builder, item)
if const_value is not None:
values.append(const_value)
continue

if isinstance(item, RefExpr):
if item.fullname == "builtins.None":
values.append(None)
elif item.fullname == "builtins.True":
values.append(True)
elif item.fullname == "builtins.False":
values.append(False)
elif isinstance(item, (BytesExpr, FloatExpr, ComplexExpr)):
# constant_fold_expr() doesn't handle these (yet?)
v = bytes_from_str(item.value) if isinstance(item, BytesExpr) else item.value
values.append(v)
elif isinstance(item, TupleExpr):
tuple_values = set_literal_values(builder, item.items)
if tuple_values is not None:
values.append(tuple(tuple_values))

if len(values) != len(items):
# Bail if not all items can be converted into values.
return None
return values


def precompute_set_literal(builder: IRBuilder, s: SetExpr) -> Value | None:
"""Try to pre-compute a frozenset literal during module initialization.

Return None if it's not possible.

Only references to "simple" final variables, tuple literals (with items that
are themselves supported), and other non-container literals are supported.
"""
values = set_literal_values(builder, s.items)
if values is not None:
return builder.add(LoadLiteral(frozenset(values), set_rprimitive))

return None


def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
# x in (...)/[...]
# x not in (...)/[...]
Expand Down Expand Up @@ -666,6 +715,23 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
else:
return builder.true()

# x in {...}
# x not in {...}
if (
first_op in ("in", "not in")
and len(e.operators) == 1
and isinstance(e.operands[1], SetExpr)
):
set_literal = precompute_set_literal(builder, e.operands[1])
if set_literal is not None:
lhs = e.operands[0]
result = builder.builder.call_c(
set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive
)
if first_op == "not in":
return builder.unary_op(result, "not", e.line)
return result

if len(e.operators) == 1:
# Special some common simple cases
if first_op in ("is", "is not"):
Expand Down
15 changes: 13 additions & 2 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Lvalue,
MemberExpr,
RefExpr,
SetExpr,
TupleExpr,
TypeAlias,
)
Expand Down Expand Up @@ -469,12 +470,22 @@ def make_for_loop_generator(
for_dict_gen.init(expr_reg, target_type)
return for_dict_gen

iterable_expr_reg: Value | None = None
if isinstance(expr, SetExpr):
# Special case "for x in <set literal>".
from mypyc.irbuild.expression import precompute_set_literal

set_literal = precompute_set_literal(builder, expr)
if set_literal is not None:
iterable_expr_reg = set_literal

# Default to a generic for loop.
expr_reg = builder.accept(expr)
if iterable_expr_reg is None:
iterable_expr_reg = builder.accept(expr)
for_obj = ForIterable(builder, index, body_block, loop_exit, line, nested)
item_type = builder._analyze_iterable_item_type(expr)
item_rtype = builder.type_to_rtype(item_type)
for_obj.init(expr_reg, item_rtype)
for_obj.init(iterable_expr_reg, item_rtype)
return for_obj


Expand Down
10 changes: 10 additions & 0 deletions mypyc/irbuild/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,13 @@ def is_constant(e: Expression) -> bool:
)
)
)


def bytes_from_str(value: str) -> bytes:
"""Convert a string representing bytes into actual bytes.

This is needed because the literal characters of BytesExpr (the
characters inside b'') are stored in BytesExpr.value, whose type is
'str' not 'bytes'.
"""
return bytes(value, "utf8").decode("unicode-escape").encode("raw-unicode-escape")
Loading