Skip to content

Commit

Permalink
Add basic support for TYP_MASK constants (dotnet#99743)
Browse files Browse the repository at this point in the history
This is to support fixing JitOptRepeat (dotnet#94250).
I was seeing failures in a Tensor test where `TYP_MASK`
generating instructions were getting CSE'd. When OptRepeat kicks in
and runs VN over the new IR, it wants to create a "zero" value
for the new CSE locals.

This change creates a `TYP_MASK` constant type, `simdmask_t`, like the
pre-existing `simd64_t`, `simd32_t`, etc. `simdmask_t` is basically a
`simd8_t` type, but with its own type. I expanded basically every place
that generally handles `simd64_t` with `simdmask_t` support. This might be
more than we currently need, but it seems to be a reasonable step towards
making `TYP_MASK` more first-class. However, I didn't go so far as to
support load/store of these constants, for example.
  • Loading branch information
BruceForstall committed Mar 14, 2024
1 parent 9fc8ae7 commit ca905a2
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 23 deletions.
12 changes: 12 additions & 0 deletions src/coreclr/jit/assertionprop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3042,6 +3042,18 @@ GenTree* Compiler::optVNBasedFoldConstExpr(BasicBlock* block, GenTree* parent, G
break;
}
break;

case TYP_MASK:
{
simdmask_t value = vnStore->ConstantValue<simdmask_t>(vnCns);

GenTreeVecCon* vecCon = gtNewVconNode(tree->TypeGet());
memcpy(&vecCon->gtSimdVal, &value, sizeof(simdmask_t));

conValTree = vecCon;
break;
}
break;
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down
41 changes: 21 additions & 20 deletions src/coreclr/jit/emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8154,13 +8154,13 @@ CORINFO_FIELD_HANDLE emitter::emitFltOrDblConst(double constValue, emitAttr attr
// Return Value:
// A field handle representing the data offset to access the constant.
//
// Note:
// Access to inline data is 'abstracted' by a special type of static member
// (produced by eeFindJitDataOffs) which the emitter recognizes as being a reference
// to constant data, not a real static field.
//
CORINFO_FIELD_HANDLE emitter::emitSimd8Const(simd8_t constValue)
{
// Access to inline data is 'abstracted' by a special type of static member
// (produced by eeFindJitDataOffs) which the emitter recognizes as being a reference
// to constant data, not a real static field.
CLANG_FORMAT_COMMENT_ANCHOR;

unsigned cnsSize = 8;
unsigned cnsAlign = cnsSize;

Expand All @@ -8177,11 +8177,6 @@ CORINFO_FIELD_HANDLE emitter::emitSimd8Const(simd8_t constValue)

CORINFO_FIELD_HANDLE emitter::emitSimd16Const(simd16_t constValue)
{
// Access to inline data is 'abstracted' by a special type of static member
// (produced by eeFindJitDataOffs) which the emitter recognizes as being a reference
// to constant data, not a real static field.
CLANG_FORMAT_COMMENT_ANCHOR;

unsigned cnsSize = 16;
unsigned cnsAlign = cnsSize;

Expand All @@ -8199,11 +8194,6 @@ CORINFO_FIELD_HANDLE emitter::emitSimd16Const(simd16_t constValue)
#if defined(TARGET_XARCH)
CORINFO_FIELD_HANDLE emitter::emitSimd32Const(simd32_t constValue)
{
// Access to inline data is 'abstracted' by a special type of static member
// (produced by eeFindJitDataOffs) which the emitter recognizes as being a reference
// to constant data, not a real static field.
CLANG_FORMAT_COMMENT_ANCHOR;

unsigned cnsSize = 32;
unsigned cnsAlign = cnsSize;

Expand All @@ -8218,11 +8208,6 @@ CORINFO_FIELD_HANDLE emitter::emitSimd32Const(simd32_t constValue)

CORINFO_FIELD_HANDLE emitter::emitSimd64Const(simd64_t constValue)
{
// Access to inline data is 'abstracted' by a special type of static member
// (produced by eeFindJitDataOffs) which the emitter recognizes as being a reference
// to constant data, not a real static field.
CLANG_FORMAT_COMMENT_ANCHOR;

unsigned cnsSize = 64;
unsigned cnsAlign = cnsSize;

Expand All @@ -8234,6 +8219,22 @@ CORINFO_FIELD_HANDLE emitter::emitSimd64Const(simd64_t constValue)
UNATIVE_OFFSET cnum = emitDataConst(&constValue, cnsSize, cnsAlign, TYP_SIMD64);
return emitComp->eeFindJitDataOffs(cnum);
}

CORINFO_FIELD_HANDLE emitter::emitSimdMaskConst(simdmask_t constValue)
{
unsigned cnsSize = 8;
unsigned cnsAlign = cnsSize;

#ifdef TARGET_XARCH
if (emitComp->compCodeOpt() == Compiler::SMALL_CODE)
{
cnsAlign = dataSection::MIN_DATA_ALIGN;
}
#endif // TARGET_XARCH

UNATIVE_OFFSET cnum = emitDataConst(&constValue, cnsSize, cnsAlign, TYP_MASK);
return emitComp->eeFindJitDataOffs(cnum);
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/emit.h
Original file line number Diff line number Diff line change
Expand Up @@ -2506,6 +2506,7 @@ class emitter
#if defined(TARGET_XARCH)
CORINFO_FIELD_HANDLE emitSimd32Const(simd32_t constValue);
CORINFO_FIELD_HANDLE emitSimd64Const(simd64_t constValue);
CORINFO_FIELD_HANDLE emitSimdMaskConst(simdmask_t constValue);
#endif // TARGET_XARCH
#endif // FEATURE_SIMD
regNumber emitInsBinary(instruction ins, emitAttr attr, GenTree* dst, GenTree* src);
Expand Down
13 changes: 13 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3406,6 +3406,13 @@ unsigned Compiler::gtHashValue(GenTree* tree)
{
#if defined(FEATURE_SIMD)
#if defined(TARGET_XARCH)
case TYP_MASK:
{
add = genTreeHashAdd(ulo32(add), vecCon->gtSimdVal.u32[1]);
add = genTreeHashAdd(ulo32(add), vecCon->gtSimdVal.u32[0]);
break;
}

case TYP_SIMD64:
{
add = genTreeHashAdd(ulo32(add), vecCon->gtSimdVal.u32[15]);
Expand Down Expand Up @@ -12237,6 +12244,12 @@ void Compiler::gtDispConst(GenTree* tree)
vecCon->gtSimdVal.u64[6], vecCon->gtSimdVal.u64[7]);
break;
}

case TYP_MASK:
{
printf("<0x%08x, 0x%08x>", vecCon->gtSimdVal.u32[0], vecCon->gtSimdVal.u32[1]);
break;
}
#endif // TARGET_XARCH

default:
Expand Down
20 changes: 18 additions & 2 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -6510,8 +6510,9 @@ struct GenTreeVecCon : public GenTree
simd16_t gtSimd16Val;

#if defined(TARGET_XARCH)
simd32_t gtSimd32Val;
simd64_t gtSimd64Val;
simd32_t gtSimd32Val;
simd64_t gtSimd64Val;
simdmask_t gtSimdMaskVal;
#endif // TARGET_XARCH

simd_t gtSimdVal;
Expand Down Expand Up @@ -6763,6 +6764,11 @@ struct GenTreeVecCon : public GenTree
{
return gtSimd64Val.IsAllBitsSet();
}

case TYP_MASK:
{
return gtSimdMaskVal.IsAllBitsSet();
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down Expand Up @@ -6810,6 +6816,11 @@ struct GenTreeVecCon : public GenTree
{
return left->gtSimd64Val == right->gtSimd64Val;
}

case TYP_MASK:
{
return left->gtSimdMaskVal == right->gtSimdMaskVal;
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down Expand Up @@ -6850,6 +6861,11 @@ struct GenTreeVecCon : public GenTree
{
return gtSimd64Val.IsZero();
}

case TYP_MASK:
{
return gtSimdMaskVal.IsZero();
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down
7 changes: 7 additions & 0 deletions src/coreclr/jit/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,13 @@ CodeGen::OperandDesc CodeGen::genOperandDesc(GenTree* op)
memcpy(&constValue, &op->AsVecCon()->gtSimdVal, sizeof(simd64_t));
return OperandDesc(emit->emitSimd64Const(constValue));
}

case TYP_MASK:
{
simdmask_t constValue;
memcpy(&constValue, &op->AsVecCon()->gtSimdVal, sizeof(simdmask_t));
return OperandDesc(emit->emitSimdMaskConst(constValue));
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/lowerxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8760,6 +8760,8 @@ void Lowering::TryFoldCnsVecForEmbeddedBroadcast(GenTreeHWIntrinsic* parentNode,
void Lowering::TryCompressConstVecData(GenTreeStoreInd* node)
{
assert(node->Data()->IsCnsVec());
assert(node->Data()->AsVecCon()->TypeIs(TYP_SIMD32, TYP_SIMD64));

GenTreeVecCon* vecCon = node->Data()->AsVecCon();
GenTreeHWIntrinsic* broadcast = nullptr;

Expand Down
49 changes: 49 additions & 0 deletions src/coreclr/jit/simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,55 @@ struct simd64_t
};
static_assert_no_msg(sizeof(simd64_t) == 64);

struct simdmask_t
{
union {
int8_t i8[8];
int16_t i16[4];
int32_t i32[2];
int64_t i64[1];
uint8_t u8[8];
uint16_t u16[4];
uint32_t u32[2];
uint64_t u64[1];
};

bool operator==(const simdmask_t& other) const
{
return (u64[0] == other.u64[0]);
}

bool operator!=(const simdmask_t& other) const
{
return !(*this == other);
}

static simdmask_t AllBitsSet()
{
simdmask_t result;

result.u64[0] = 0xFFFFFFFFFFFFFFFF;

return result;
}

bool IsAllBitsSet() const
{
return *this == AllBitsSet();
}

bool IsZero() const
{
return *this == Zero();
}

static simdmask_t Zero()
{
return {};
}
};
static_assert_no_msg(sizeof(simdmask_t) == 8);

typedef simd64_t simd_t;
#else
typedef simd16_t simd_t;
Expand Down
62 changes: 61 additions & 1 deletion src/coreclr/jit/valuenum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ ValueNumStore::ValueNumStore(Compiler* comp, CompAllocator alloc)
#if defined(TARGET_XARCH)
, m_simd32CnsMap(nullptr)
, m_simd64CnsMap(nullptr)
, m_simdMaskCnsMap(nullptr)
#endif // TARGET_XARCH
#endif // FEATURE_SIMD
, m_VNFunc0Map(nullptr)
Expand Down Expand Up @@ -1706,6 +1707,12 @@ ValueNumStore::Chunk::Chunk(CompAllocator alloc, ValueNum* pNextBaseVN, var_type
m_defs = new (alloc) Alloc<TYP_SIMD64>::Type[ChunkSize];
break;
}

case TYP_MASK:
{
m_defs = new (alloc) Alloc<TYP_MASK>::Type[ChunkSize];
break;
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down Expand Up @@ -1870,6 +1877,11 @@ ValueNum ValueNumStore::VNForSimd64Con(simd64_t cnsVal)
{
return VnForConst(cnsVal, GetSimd64CnsMap(), TYP_SIMD64);
}

ValueNum ValueNumStore::VNForSimdMaskCon(simdmask_t cnsVal)
{
return VnForConst(cnsVal, GetSimdMaskCnsMap(), TYP_MASK);
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down Expand Up @@ -1971,6 +1983,11 @@ ValueNum ValueNumStore::VNForGenericCon(var_types typ, uint8_t* cnsVal)
READ_VALUE(simd64_t);
return VNForSimd64Con(val);
}
case TYP_MASK:
{
READ_VALUE(simdmask_t);
return VNForSimdMaskCon(val);
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD
default:
Expand Down Expand Up @@ -2085,6 +2102,11 @@ ValueNum ValueNumStore::VNZeroForType(var_types typ)
{
return VNForSimd64Con(simd64_t::Zero());
}

case TYP_MASK:
{
return VNForSimdMaskCon(simdmask_t::Zero());
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down Expand Up @@ -2175,6 +2197,11 @@ ValueNum ValueNumStore::VNAllBitsForType(var_types typ)
{
return VNForSimd64Con(simd64_t::AllBitsSet());
}

case TYP_MASK:
{
return VNForSimdMaskCon(simdmask_t::AllBitsSet());
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down Expand Up @@ -2296,6 +2323,13 @@ ValueNum ValueNumStore::VNOneForSimdType(var_types simdType, var_types simdBaseT
memcpy(&simd64Val, &simdVal, sizeof(simd64_t));
return VNForSimd64Con(simd64Val);
}

case TYP_MASK:
{
// '1' doesn't make sense for TYP_MASK?
// Or should it be AllBitsSet?
unreached();
}
#endif // TARGET_XARCH

default:
Expand Down Expand Up @@ -3741,7 +3775,7 @@ simd32_t ValueNumStore::GetConstantSimd32(ValueNum argVN)
return ConstantValue<simd32_t>(argVN);
}

// Given a simd64 constant value number return its value as a simd32.
// Given a simd64 constant value number return its value as a simd64.
//
simd64_t ValueNumStore::GetConstantSimd64(ValueNum argVN)
{
Expand All @@ -3750,6 +3784,16 @@ simd64_t ValueNumStore::GetConstantSimd64(ValueNum argVN)

return ConstantValue<simd64_t>(argVN);
}

// Given a simdmask constant value number return its value as a simdmask.
//
simdmask_t ValueNumStore::GetConstantSimdMask(ValueNum argVN)
{
assert(IsVNConstant(argVN));
assert(TypeOfVN(argVN) == TYP_MASK);

return ConstantValue<simdmask_t>(argVN);
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down Expand Up @@ -9221,6 +9265,13 @@ void ValueNumStore::vnDump(Compiler* comp, ValueNum vn, bool isPtr)
cnsVal.u64[6], cnsVal.u64[7]);
break;
}

case TYP_MASK:
{
simdmask_t cnsVal = GetConstantSimdMask(vn);
printf("SimdMaskCns[0x%08x, 0x%08x]", cnsVal.u32[0], cnsVal.u32[1]);
break;
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down Expand Up @@ -10661,6 +10712,15 @@ void Compiler::fgValueNumberTreeConst(GenTree* tree)
tree->gtVNPair.SetBoth(vnStore->VNForSimd64Con(simd64Val));
break;
}

case TYP_MASK:
{
simdmask_t simdmaskVal;
memcpy(&simdmaskVal, &tree->AsVecCon()->gtSimdVal, sizeof(simdmask_t));

tree->gtVNPair.SetBoth(vnStore->VNForSimdMaskCon(simdmaskVal));
break;
}
#endif // TARGET_XARCH
#endif // FEATURE_SIMD

Expand Down
Loading

0 comments on commit ca905a2

Please sign in to comment.