Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT: Added Sve.CreateBreakPropagateMask #104704

Merged
merged 13 commits into from
Jul 17, 2024
47 changes: 27 additions & 20 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2286,10 +2286,16 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}

#if defined(FEATURE_MASKED_HW_INTRINSICS) && defined(TARGET_ARM64)
// Ensures `op` is mask-typed: when it is not, it is replaced by a vector-to-mask
// conversion node built from the current simdBaseJitType/simdSize. `op` is taken by
// reference so the replacement is written back into the caller's operand.
auto convertToMaskIfNeeded = [&](GenTree*& op) {
    if (varTypeIsMask(op))
    {
        // Already a mask; no conversion node is needed.
        return;
    }

    op = gtNewSimdCvtVectorToMaskNode(TYP_MASK, op, simdBaseJitType, simdSize);
};

if (HWIntrinsicInfo::IsExplicitMaskedOperation(intrinsic))
{
assert(numArgs > 0);
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);

switch (intrinsic)
{
Expand All @@ -2304,14 +2310,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
case NI_Sve_TestFirstTrue:
case NI_Sve_TestLastTrue:
{
GenTree* op2 = retNode->AsHWIntrinsic()->Op(2);

// HWInstrinsic requires a mask for op2
if (!varTypeIsMask(op2))
{
retNode->AsHWIntrinsic()->Op(2) =
gtNewSimdCvtVectorToMaskNode(TYP_MASK, op2, simdBaseJitType, simdSize);
}
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(2));
break;
}

Expand All @@ -2324,26 +2324,17 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
case NI_Sve_CreateBreakAfterPropagateMask:
case NI_Sve_CreateBreakBeforePropagateMask:
{
GenTree* op3 = retNode->AsHWIntrinsic()->Op(3);

// HWInstrinsic requires a mask for op3
if (!varTypeIsMask(op3))
{
retNode->AsHWIntrinsic()->Op(3) =
gtNewSimdCvtVectorToMaskNode(TYP_MASK, op3, simdBaseJitType, simdSize);
}
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(3));
break;
}

default:
break;
}

if (!varTypeIsMask(op1))
{
// Op1 input is a vector. HWInstrinsic requires a mask.
retNode->AsHWIntrinsic()->Op(1) = gtNewSimdCvtVectorToMaskNode(TYP_MASK, op1, simdBaseJitType, simdSize);
}
// HWIntrinsic requires a mask for op1
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(1));

if (HWIntrinsicInfo::IsMultiReg(intrinsic))
{
Expand All @@ -2354,6 +2345,22 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}
}

if (HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinsic))
kunalspathak marked this conversation as resolved.
Show resolved Hide resolved
{
switch (intrinsic)
{
case NI_Sve_CreateBreakPropagateMask:
{
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(1));
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(2));
break;
}

default:
break;
}
}

if (retType != nodeRetType)
{
// HWIntrinsic returns a mask, but all returns must be vectors, so convert mask to vector.
Expand Down
34 changes: 26 additions & 8 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
insScalableOpts sopt = INS_SCALABLE_OPTS_NONE;
bool hasShift = false;

insOpts embOpt = opt;
switch (intrinEmbMask.id)
{
case NI_Sve_ShiftLeftLogical:
Expand All @@ -689,6 +690,10 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
hasShift = true;
break;

case NI_Sve_CreateBreakPropagateMask:
embOpt = INS_OPTS_SCALABLE_B;
break;

default:
break;
}
Expand All @@ -699,13 +704,13 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op2, op2->AsHWIntrinsic());
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
GetEmitter()->emitInsSve_R_R_I(insEmbMask, emitSize, reg1, reg2, helper.ImmValue(), opt,
sopt);
GetEmitter()->emitInsSve_R_R_I(insEmbMask, emitSize, reg1, reg2, helper.ImmValue(),
embOpt, sopt);
}
}
else
{
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, reg1, reg2, reg3, opt, sopt);
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, reg1, reg2, reg3, embOpt, sopt);
}
};

Expand All @@ -714,12 +719,25 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
// If `falseReg` is zero, then move the first operand of `intrinEmbMask` in the
// destination using /Z.

assert(targetReg != embMaskOp2Reg);
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, emitSize, targetReg, maskReg, embMaskOp1Reg, opt);
switch (intrinEmbMask.id)
{
case NI_Sve_CreateBreakPropagateMask:
assert(targetReg != embMaskOp1Reg);
GetEmitter()->emitIns_Mov(INS_sve_mov, emitSize, targetReg, embMaskOp2Reg,
/* canSkip */ true);
emitInsHelper(targetReg, maskReg, embMaskOp1Reg);
break;

// Finally, perform the actual "predicated" operation so that `targetReg` is the first operand
// and `embMaskOp2Reg` is the second operand.
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
default:
assert(targetReg != embMaskOp2Reg);
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, emitSize, targetReg, maskReg,
embMaskOp1Reg, opt);

// Finally, perform the actual "predicated" operation so that `targetReg` is the first
// operand and `embMaskOp2Reg` is the second operand.
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
break;
}
}
else if (targetReg != falseReg)
{
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ HARDWARE_INTRINSIC(Sve, CreateBreakAfterMask,
HARDWARE_INTRINSIC(Sve, CreateBreakAfterPropagateMask, -1, 3, true, {INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, CreateBreakBeforeMask, -1, 2, true, {INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, CreateBreakBeforePropagateMask, -1, 3, true, {INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, CreateBreakPropagateMask, -1, -1, false, {INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_HasRMWSemantics)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this have HW_Flag_ExplicitMaskedOperation too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the above comment, I can't combine ExplicitMaskedOperation and EmbeddedMaskedOperation.

HARDWARE_INTRINSIC(Sve, CreateFalseMaskByte, -1, 0, false, {INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateFalseMaskDouble, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateFalseMaskInt16, -1, 0, false, {INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
Expand Down
40 changes: 31 additions & 9 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,8 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
}
else
{
SingleTypeRegSet predMask = RBM_ALLMASK.GetPredicateRegSet();
bool tgtPrefEmbOp2 = false;
SingleTypeRegSet predMask = RBM_ALLMASK.GetPredicateRegSet();
if (intrin.id == NI_Sve_ConditionalSelect)
{
// If this is conditional select, make sure to check the embedded
Expand All @@ -1658,16 +1659,26 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
{
predMask = RBM_LOWMASK.GetPredicateRegSet();
}

// Special-case, CreateBreakPropagateMask's op2 is the RMW node.
if (intrinEmb.id == NI_Sve_CreateBreakPropagateMask)
{
assert(embOp2Node->isRMWHWIntrinsic(compiler));
assert(!tgtPrefOp1);
assert(!tgtPrefOp2);
tgtPrefEmbOp2 = true;
}
}
}
else if (HWIntrinsicInfo::IsLowMaskedOperation(intrin.id))
{
predMask = RBM_LOWMASK.GetPredicateRegSet();
}

if (tgtPrefOp2)
if (tgtPrefOp2 || tgtPrefEmbOp2)
{
srcCount += BuildDelayFreeUses(intrin.op1, intrin.op2, predMask);
assert(!tgtPrefOp1);
srcCount += BuildDelayFreeUses(intrin.op1, nullptr, predMask);
}
else
{
Expand Down Expand Up @@ -1983,15 +1994,26 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
break;
}

tgtPrefUse = BuildUse(embOp2Node->Op(1));
srcCount += 1;

for (size_t argNum = 2; argNum <= numArgs; argNum++)
size_t prefUseOpNum = 1;
if (intrinEmb.id == NI_Sve_CreateBreakPropagateMask)
{
prefUseOpNum = 2;
}
GenTree* prefUseNode = embOp2Node->Op(prefUseOpNum);
for (size_t argNum = 1; argNum <= numArgs; argNum++)
{
srcCount += BuildDelayFreeUses(embOp2Node->Op(argNum), embOp2Node->Op(1));
if (argNum == prefUseOpNum)
{
tgtPrefUse = BuildUse(prefUseNode);
srcCount += 1;
}
else
{
srcCount += BuildDelayFreeUses(embOp2Node->Op(argNum), prefUseNode);
}
}

srcCount += BuildDelayFreeUses(intrin.op3, embOp2Node->Op(1));
srcCount += BuildDelayFreeUses(intrin.op3, prefUseNode);
}
}
else if (intrin.op2 != nullptr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2288,6 +2288,55 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateBreakBeforePropagateMask(Vector<ulong> mask, Vector<ulong> left, Vector<ulong> right) { throw new PlatformNotSupportedException(); }


/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<sbyte> CreateBreakPropagateMask(Vector<sbyte> totalMask, Vector<sbyte> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<short> CreateBreakPropagateMask(Vector<short> totalMask, Vector<short> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<int> CreateBreakPropagateMask(Vector<int> totalMask, Vector<int> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<long> CreateBreakPropagateMask(Vector<long> totalMask, Vector<long> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<byte> CreateBreakPropagateMask(Vector<byte> totalMask, Vector<byte> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<ushort> CreateBreakPropagateMask(Vector<ushort> totalMask, Vector<ushort> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<uint> CreateBreakPropagateMask(Vector<uint> totalMask, Vector<uint> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<ulong> CreateBreakPropagateMask(Vector<ulong> totalMask, Vector<ulong> fromMask) { throw new PlatformNotSupportedException(); }


/// Set all predicate elements to false

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2346,6 +2346,55 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateBreakBeforePropagateMask(Vector<ulong> mask, Vector<ulong> left, Vector<ulong> right) => CreateBreakBeforePropagateMask(mask, left, right);


/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<sbyte> CreateBreakPropagateMask(Vector<sbyte> totalMask, Vector<sbyte> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<short> CreateBreakPropagateMask(Vector<short> totalMask, Vector<short> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<int> CreateBreakPropagateMask(Vector<int> totalMask, Vector<int> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<long> CreateBreakPropagateMask(Vector<long> totalMask, Vector<long> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<byte> CreateBreakPropagateMask(Vector<byte> totalMask, Vector<byte> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<ushort> CreateBreakPropagateMask(Vector<ushort> totalMask, Vector<ushort> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<uint> CreateBreakPropagateMask(Vector<uint> totalMask, Vector<uint> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<ulong> CreateBreakPropagateMask(Vector<ulong> totalMask, Vector<ulong> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);


/// Set all predicate elements to false

/// <summary>
Expand Down
Loading
Loading