Skip to content

Commit

Permalink
[Mono] Add SIMD intrinsic for Vector64/128 comparisons (dotnet#65128)
Browse files Browse the repository at this point in the history
* Add vector comparison intrinsics

* Add EqualsAll and EqualsAny intrinsics

* Remove broken EqualsAny

* Fix EqualsAny

* Enable xequal also for arm64

* Fix xzero type

* Fix bad merge

* Add guards for invalid types

* Revert unrelated change

* Extract duplicate code blocks to a new function

* Fix EqualsAny

* Fix typo + code improvements
  • Loading branch information
simonrozsival committed Feb 25, 2022
1 parent 9012b23 commit afc2f11
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 85 deletions.
149 changes: 76 additions & 73 deletions src/mono/mono/mini/mini-llvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -9465,79 +9465,6 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
values [ins->dreg] = LLVMBuildSExt (builder, cmp, LLVMTypeOf (lhs), "");
break;
}
case OP_XEQUAL: {
LLVMTypeRef t;
LLVMValueRef cmp, mask [MAX_VECTOR_ELEMS], shuffle;
int nelems;

#if defined(TARGET_WASM)
/* The wasm code generator doesn't understand the shuffle/and code sequence below */
LLVMValueRef val;
if (LLVMIsNull (lhs) || LLVMIsNull (rhs)) {
val = LLVMIsNull (lhs) ? rhs : lhs;
nelems = LLVMGetVectorSize (LLVMTypeOf (lhs));

IntrinsicId intrins = (IntrinsicId)0;
switch (nelems) {
case 16:
intrins = INTRINS_WASM_ANYTRUE_V16;
break;
case 8:
intrins = INTRINS_WASM_ANYTRUE_V8;
break;
case 4:
intrins = INTRINS_WASM_ANYTRUE_V4;
break;
case 2:
intrins = INTRINS_WASM_ANYTRUE_V2;
break;
default:
g_assert_not_reached ();
}
/* res = !wasm.anytrue (val) */
values [ins->dreg] = call_intrins (ctx, intrins, &val, "");
values [ins->dreg] = LLVMBuildZExt (builder, LLVMBuildICmp (builder, LLVMIntEQ, values [ins->dreg], LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""), LLVMInt32Type (), dname);
break;
}
#endif
LLVMTypeRef srcelemt = LLVMGetElementType (LLVMTypeOf (lhs));

//%c = icmp sgt <16 x i8> %a0, %a1
if (srcelemt == LLVMDoubleType () || srcelemt == LLVMFloatType ())
cmp = LLVMBuildFCmp (builder, LLVMRealOEQ, lhs, rhs, "");
else
cmp = LLVMBuildICmp (builder, LLVMIntEQ, lhs, rhs, "");
nelems = LLVMGetVectorSize (LLVMTypeOf (cmp));

LLVMTypeRef elemt;
if (srcelemt == LLVMDoubleType ())
elemt = LLVMInt64Type ();
else if (srcelemt == LLVMFloatType ())
elemt = LLVMInt32Type ();
else
elemt = srcelemt;

t = LLVMVectorType (elemt, nelems);
cmp = LLVMBuildSExt (builder, cmp, t, "");
// cmp is a <nelems x elemt> vector, each element is either 0xff... or 0
int half = nelems / 2;
while (half >= 1) {
// AND the top and bottom halfes into the bottom half
for (int i = 0; i < half; ++i)
mask [i] = LLVMConstInt (LLVMInt32Type (), half + i, FALSE);
for (int i = half; i < nelems; ++i)
mask [i] = LLVMConstInt (LLVMInt32Type (), 0, FALSE);
shuffle = LLVMBuildShuffleVector (builder, cmp, LLVMGetUndef (t), LLVMConstVector (mask, LLVMGetVectorSize (t)), "");
cmp = LLVMBuildAnd (builder, cmp, shuffle, "");
half = half / 2;
}
// Extract [0]
LLVMValueRef first_elem = LLVMBuildExtractElement (builder, cmp, LLVMConstInt (LLVMInt32Type (), 0, FALSE), "");
// convert to 0/1
LLVMValueRef cmp_zero = LLVMBuildICmp (builder, LLVMIntNE, first_elem, LLVMConstInt (elemt, 0, FALSE), "");
values [ins->dreg] = LLVMBuildZExt (builder, cmp_zero, LLVMInt8Type (), "");
break;
}
case OP_POPCNT32:
values [ins->dreg] = call_intrins (ctx, INTRINS_CTPOP_I32, &lhs, "");
break;
Expand Down Expand Up @@ -9616,6 +9543,82 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
}
#endif

#if defined(TARGET_ARM64) || defined(TARGET_X86) || defined(TARGET_AMD64) || defined(TARGET_WASM)
case OP_XEQUAL: {
LLVMTypeRef t;
LLVMValueRef cmp, mask [MAX_VECTOR_ELEMS], shuffle;
int nelems;

#if defined(TARGET_WASM)
/* The wasm code generator doesn't understand the shuffle/and code sequence below */
LLVMValueRef val;
if (LLVMIsNull (lhs) || LLVMIsNull (rhs)) {
val = LLVMIsNull (lhs) ? rhs : lhs;
nelems = LLVMGetVectorSize (LLVMTypeOf (lhs));

IntrinsicId intrins = (IntrinsicId)0;
switch (nelems) {
case 16:
intrins = INTRINS_WASM_ANYTRUE_V16;
break;
case 8:
intrins = INTRINS_WASM_ANYTRUE_V8;
break;
case 4:
intrins = INTRINS_WASM_ANYTRUE_V4;
break;
case 2:
intrins = INTRINS_WASM_ANYTRUE_V2;
break;
default:
g_assert_not_reached ();
}
/* res = !wasm.anytrue (val) */
values [ins->dreg] = call_intrins (ctx, intrins, &val, "");
values [ins->dreg] = LLVMBuildZExt (builder, LLVMBuildICmp (builder, LLVMIntEQ, values [ins->dreg], LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""), LLVMInt32Type (), dname);
break;
}
#endif
LLVMTypeRef srcelemt = LLVMGetElementType (LLVMTypeOf (lhs));

//%c = icmp sgt <16 x i8> %a0, %a1
if (srcelemt == LLVMDoubleType () || srcelemt == LLVMFloatType ())
cmp = LLVMBuildFCmp (builder, LLVMRealOEQ, lhs, rhs, "");
else
cmp = LLVMBuildICmp (builder, LLVMIntEQ, lhs, rhs, "");
nelems = LLVMGetVectorSize (LLVMTypeOf (cmp));

LLVMTypeRef elemt;
if (srcelemt == LLVMDoubleType ())
elemt = LLVMInt64Type ();
else if (srcelemt == LLVMFloatType ())
elemt = LLVMInt32Type ();
else
elemt = srcelemt;

t = LLVMVectorType (elemt, nelems);
cmp = LLVMBuildSExt (builder, cmp, t, "");
// cmp is a <nelems x elemt> vector, each element is either 0xff... or 0
int half = nelems / 2;
while (half >= 1) {
// AND the top and bottom halfes into the bottom half
for (int i = 0; i < half; ++i)
mask [i] = LLVMConstInt (LLVMInt32Type (), half + i, FALSE);
for (int i = half; i < nelems; ++i)
mask [i] = LLVMConstInt (LLVMInt32Type (), 0, FALSE);
shuffle = LLVMBuildShuffleVector (builder, cmp, LLVMGetUndef (t), LLVMConstVector (mask, LLVMGetVectorSize (t)), "");
cmp = LLVMBuildAnd (builder, cmp, shuffle, "");
half = half / 2;
}
// Extract [0]
LLVMValueRef first_elem = LLVMBuildExtractElement (builder, cmp, LLVMConstInt (LLVMInt32Type (), 0, FALSE), "");
// convert to 0/1
LLVMValueRef cmp_zero = LLVMBuildICmp (builder, LLVMIntNE, first_elem, LLVMConstInt (elemt, 0, FALSE), "");
values [ins->dreg] = LLVMBuildZExt (builder, cmp_zero, LLVMInt8Type (), "");
break;
}
#endif

#if defined(TARGET_ARM64)

case OP_XOP_I4_I4:
Expand Down
112 changes: 100 additions & 12 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,29 @@ emit_xcompare (MonoCompile *cfg, MonoClass *klass, MonoTypeEnum etype, MonoInst
return ins;
}

static MonoInst*
emit_xequal (MonoCompile *cfg, MonoClass *klass, MonoInst *arg1, MonoInst *arg2)
{
return emit_simd_ins (cfg, klass, OP_XEQUAL, arg1->dreg, arg2->dreg);
}

static MonoInst*
emit_not_xequal (MonoCompile *cfg, MonoClass *klass, MonoInst *arg1, MonoInst *arg2)
{
MonoInst *ins = emit_simd_ins (cfg, klass, OP_XEQUAL, arg1->dreg, arg2->dreg);
int sreg = ins->dreg;
int dreg = alloc_ireg (cfg);
MONO_EMIT_NEW_BIALU_IMM (cfg, OP_COMPARE_IMM, -1, sreg, 0);
EMIT_NEW_UNALU (cfg, ins, OP_CEQ, dreg, -1);
return ins;
}

static MonoInst*
emit_xzero (MonoCompile *cfg, MonoClass *klass)
{
return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
}

static gboolean
is_intrinsics_vector_type (MonoType *vector_type)
{
Expand Down Expand Up @@ -492,7 +515,7 @@ emit_vector_create_elementwise (
{
int op = type_to_insert_op (etype);
MonoClass *vklass = mono_class_from_mono_type_internal (vtype);
MonoInst *ins = emit_simd_ins (cfg, vklass, OP_XZERO, -1, -1);
MonoInst *ins = emit_xzero (cfg, vklass);
for (int i = 0; i < fsig->param_count; ++i) {
ins = emit_simd_ins (cfg, vklass, op, ins->dreg, args [i]->dreg);
ins->inst_c0 = i;
Expand Down Expand Up @@ -590,10 +613,17 @@ static guint16 sri_vector_methods [] = {
SN_CreateScalar,
SN_CreateScalarUnsafe,
SN_Divide,
SN_Equals,
SN_EqualsAll,
SN_EqualsAny,
SN_Floor,
SN_GetElement,
SN_GetLower,
SN_GetUpper,
SN_GreaterThan,
SN_GreaterThanOrEqual,
SN_LessThan,
SN_LessThanOrEqual,
SN_Max,
SN_Min,
SN_Multiply,
Expand Down Expand Up @@ -788,6 +818,27 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
return emit_simd_ins_for_sig (cfg, klass, OP_CREATE_SCALAR, -1, arg0_type, fsig, args);
case SN_CreateScalarUnsafe:
return emit_simd_ins_for_sig (cfg, klass, OP_CREATE_SCALAR_UNSAFE, -1, arg0_type, fsig, args);
case SN_Equals:
case SN_EqualsAll:
case SN_EqualsAny: {
MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
return NULL;

switch (id) {
case SN_Equals:
return emit_xcompare (cfg, klass, arg0_type, args [0], args [1]);
case SN_EqualsAll:
return emit_xequal (cfg, klass, args [0], args [1]);
case SN_EqualsAny: {
MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
MonoInst *cmp_eq = emit_xcompare (cfg, arg_class, arg0_type, args [0], args [1]);
MonoInst *zero = emit_xzero (cfg, arg_class);
return emit_not_xequal (cfg, arg_class, cmp_eq, zero);
}
default: g_assert_not_reached ();
}
}
case SN_GetElement: {
MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
MonoType *etype = mono_class_get_context (arg_class)->class_inst->type_argv [0];
Expand All @@ -809,6 +860,34 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
int op = id == SN_GetLower ? OP_XLOWER : OP_XUPPER;
return emit_simd_ins_for_sig (cfg, klass, op, 0, arg0_type, fsig, args);
}
case SN_GreaterThan:
case SN_GreaterThanOrEqual:
case SN_LessThan:
case SN_LessThanOrEqual: {
MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
return NULL;

gboolean is_unsigned = type_is_unsigned (fsig->params [0]);
MonoInst *ins = emit_xcompare (cfg, klass, arg0_type, args [0], args [1]);
switch (id) {
case SN_GreaterThan:
ins->inst_c0 = is_unsigned ? CMP_GT_UN : CMP_GT;
break;
case SN_GreaterThanOrEqual:
ins->inst_c0 = is_unsigned ? CMP_GE_UN : CMP_GE;
break;
case SN_LessThan:
ins->inst_c0 = is_unsigned ? CMP_LT_UN : CMP_LT;
break;
case SN_LessThanOrEqual:
ins->inst_c0 = is_unsigned ? CMP_LE_UN : CMP_LE;
break;
default:
g_assert_not_reached ();
}
return ins;
}
case SN_Negate:
case SN_OnesComplement: {
#ifdef TARGET_ARM64
Expand Down Expand Up @@ -879,6 +958,8 @@ static guint16 vector64_vector128_t_methods [] = {
SN_get_Count,
SN_get_IsSupported,
SN_get_Zero,
SN_op_Equality,
SN_op_Inequality,
};

static MonoInst*
Expand Down Expand Up @@ -928,10 +1009,10 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign
return ins;
}
case SN_get_Zero: {
return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
return emit_xzero (cfg, klass);
}
case SN_get_AllBitsSet: {
MonoInst *ins = emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
MonoInst *ins = emit_xzero (cfg, klass);
return emit_xcompare (cfg, klass, etype->type, ins, ins);
}
case SN_Equals: {
Expand All @@ -941,6 +1022,16 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign
}
break;
}
case SN_op_Equality:
case SN_op_Inequality:
g_assert (fsig->param_count == 2 && fsig->ret->type == MONO_TYPE_BOOLEAN &&
mono_metadata_type_equal (fsig->params [0], type) &&
mono_metadata_type_equal (fsig->params [1], type));
switch (id) {
case SN_op_Equality: return emit_xequal (cfg, klass, args [0], args [1]);
case SN_op_Inequality: return emit_not_xequal (cfg, klass, args [0], args [1]);
default: g_assert_not_reached ();
}
default:
break;
}
Expand Down Expand Up @@ -1086,7 +1177,7 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig
return ins;
case SN_get_Zero:
g_assert (fsig->param_count == 0 && mono_metadata_type_equal (fsig->ret, type));
return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
return emit_xzero (cfg, klass);
case SN_get_One: {
g_assert (fsig->param_count == 0 && mono_metadata_type_equal (fsig->ret, type));
MonoInst *one = NULL;
Expand Down Expand Up @@ -1115,7 +1206,7 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig
}
case SN_get_AllBitsSet: {
/* Compare a zero vector with itself */
ins = emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
ins = emit_xzero (cfg, klass);
return emit_xcompare (cfg, klass, etype->type, ins, ins);
}
case SN_get_Item: {
Expand Down Expand Up @@ -1222,14 +1313,11 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig
g_assert (fsig->param_count == 2 && fsig->ret->type == MONO_TYPE_BOOLEAN &&
mono_metadata_type_equal (fsig->params [0], type) &&
mono_metadata_type_equal (fsig->params [1], type));
ins = emit_simd_ins (cfg, klass, OP_XEQUAL, args [0]->dreg, args [1]->dreg);
if (id == SN_op_Inequality) {
int sreg = ins->dreg;
int dreg = alloc_ireg (cfg);
MONO_EMIT_NEW_BIALU_IMM (cfg, OP_COMPARE_IMM, -1, sreg, 0);
EMIT_NEW_UNALU (cfg, ins, OP_CEQ, dreg, -1);
switch (id) {
case SN_op_Equality: return emit_xequal (cfg, klass, args [0], args [1]);
case SN_op_Inequality: return emit_not_xequal (cfg, klass, args [0], args [1]);
default: g_assert_not_reached ();
}
return ins;
case SN_GreaterThan:
case SN_GreaterThanOrEqual:
case SN_LessThan:
Expand Down
2 changes: 2 additions & 0 deletions src/mono/mono/mini/simd-methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ METHOD(Create)
METHOD(CreateScalar)
METHOD(CreateScalarUnsafe)
METHOD(ConditionalSelect)
METHOD(EqualsAll)
METHOD(EqualsAny)
METHOD(GetElement)
METHOD(GetLower)
METHOD(GetUpper)
Expand Down

0 comments on commit afc2f11

Please sign in to comment.