Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added is identity comparison for numpy records and added tests as is_comparison_tests.py #9384

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion numba/np/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -3286,10 +3286,19 @@ def array_is_impl(a, b):

return context.compile_internal(builder, array_is_impl, sig, args)


@lower_builtin(operator.is_, types.Record, types.Record)
def record_is(context, builder, sig, args):
aty, bty = sig.args
if aty != bty:
return cgutils.false_bit

return builder.icmp_unsigned('==', args[0], args[1])
guilhermeleobas marked this conversation as resolved.
Show resolved Hide resolved


# ------------------------------------------------------------------------------
# Hash


@overload_attribute(types.Array, "__hash__")
def ol_array_hash(arr):
return lambda arr: None
Expand Down
57 changes: 56 additions & 1 deletion numba/tests/test_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from numba.core.errors import TypingError, NumbaValueError
from numba.np.numpy_support import as_dtype, numpy_version
from numba.tests.support import TestCase, MemoryLeakMixin, needs_blas
from numba.core import ir, cgutils
from np.arrayobj import record_is
guilhermeleobas marked this conversation as resolved.
Show resolved Hide resolved

TIMEDELTA_M = 'timedelta64[M]'
TIMEDELTA_Y = 'timedelta64[Y]'
Expand Down Expand Up @@ -1763,7 +1765,7 @@ def test_array_ctor_with_dtype_arg(self):
np.testing.assert_array_equal(pyfunc(*args), cfunc(*args))

class TestArrayComparisons(TestCase):

guilhermeleobas marked this conversation as resolved.
Show resolved Hide resolved
def test_identity(self):
def check(a, b, expected):
cfunc = njit((typeof(a), typeof(b)))(pyfunc)
Expand All @@ -1780,6 +1782,59 @@ def check(a, b, expected):
check(arr, arr.T, False)
check(arr, arr[:-1], False)

# tests for record comparison
guilhermeleobas marked this conversation as resolved.
Show resolved Hide resolved
def create_mock_context():
return cgutils.lowering.Context("test_module")

def create_mock_builder(context, func):
block = context.append_basic_block(func, "entry")
return cgutils.Builder(context, block)
def test_equal_records(self):

sig = types.signature(types.boolean,
types.Record(('a', types.int32), ('b', types.int32)),
types.Record(('a', types.int32), ('b', types.int32)))

context = create_mock_context()
func = ir.Function(ir.Module(), ir.FunctionType(ir.types.Boolean(), [ir.types.Int(32), ir.types.Int(32)]), "test_function")
builder = create_mock_builder(context, func)

args = (cgutils.create_record(context, builder, sig.args[0]),
cgutils.create_record(context, builder, sig.args[1]))

result = record_is(context, builder, sig, args)
self.assertTrue(result, "Expected True for equal records")

def test_different_records(self):
sig = types.signature(types.boolean,
types.Record(('a', types.int32), ('b', types.int32)),
types.Record(('a', types.int32), ('c', types.int32)))

context = create_mock_context()
func = ir.Function(ir.Module(), ir.FunctionType(ir.types.Boolean(), [ir.types.Int(32), ir.types.Int(32)]), "test_function")
builder = create_mock_builder(context, func)

args = (cgutils.create_record(context, builder, sig.args[0]),
cgutils.create_record(context, builder, sig.args[1]))

result = record_is(context, builder, sig, args)
self.assertFalse(result, "Expected False for different records")

def test_mixed_records(self):
sig = types.signature(types.boolean,
types.Record(('a', types.int32), ('b', types.int32)),
types.Record(('a', types.int32), ('b', types.float64)))

context = create_mock_context()
func = ir.Function(ir.Module(), ir.FunctionType(ir.types.Boolean(), [ir.types.Int(32), ir.types.Int(32)]), "test_function")
builder = create_mock_builder(context, func)

args = (cgutils.create_record(context, builder, sig.args[0]),
cgutils.create_record(context, builder, sig.args[1]))

result = record_is(context, builder, sig, args)
self.assertFalse(result, "Expected False for mixed records")

# Other comparison operators ('==', etc.) are tested in test_ufuncs


Expand Down
Loading