Skip to content

Commit

Permalink
[DAG] Fold concat_vectors(concat_vectors(x,y),concat_vectors(a,b)) ->…
Browse files Browse the repository at this point in the history
… concat_vectors(x,y,a,b)

Follow-up to D107068, attempt to fold nested concat_vectors/undefs, as long as both the vector and inner subvector types are legal.

This exposed the same issue in ARM's MVE LowerCONCAT_VECTORS_i1 (raised as PR51365) and in AArch64's performConcatVectorsCombine, both of which assumed that concat_vectors takes only 2 subvector operands.

Differential Revision: https://reviews.llvm.org/D107597
  • Loading branch information
RKSimon committed Aug 16, 2021
1 parent b4a1f44 commit d6fe8d3
Show file tree
Hide file tree
Showing 11 changed files with 359 additions and 420 deletions.
48 changes: 46 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19865,6 +19865,44 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
}

// Flatten a CONCAT_VECTORS whose operands are all CONCAT_VECTORS or UNDEF
// nodes into a single CONCAT_VECTORS over the inner subvectors, e.g.
// concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
// --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
//
// Returns an empty SDValue (no fold) when any non-undef operand is not a
// CONCAT_VECTORS, when the inner subvector type is not legal for the target,
// or when the nested concats disagree on their inner subvector type.
static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
                                                  SelectionDAG &DAG) {
  const auto &TLI = DAG.getTargetLoweringInfo();

  // Validation pass: every operand must be UNDEF or a CONCAT_VECTORS, and all
  // nested concats must share one (legal) inner subvector type. The first
  // nested concat encountered is remembered as the reference node.
  EVT InnerVT;
  SDValue Representative;
  for (const SDValue &Operand : N->ops()) {
    if (Operand.isUndef())
      continue;
    if (Operand.getOpcode() != ISD::CONCAT_VECTORS)
      return SDValue();

    EVT OperandInnerVT = Operand.getOperand(0).getValueType();
    if (Representative) {
      // Later concats only need to agree with the reference type.
      if (OperandInnerVT != InnerVT)
        return SDValue();
      continue;
    }

    // First nested concat: its inner type must be legal before we commit.
    if (!TLI.isTypeLegal(OperandInnerVT))
      return SDValue();
    InnerVT = OperandInnerVT;
    Representative = Operand;
  }
  assert(Representative && "Concat of all-undefs found");

  // Rewrite pass: splice the inner operands of each nested concat in order.
  // An UNDEF operand is widened to as many InnerVT undefs as the reference
  // concat has operands, so it occupies the same number of result slots.
  const unsigned NumInnerOps = Representative->getNumOperands();
  SmallVector<SDValue> ConcatOps;
  for (const SDValue &Operand : N->ops()) {
    if (Operand.isUndef())
      ConcatOps.append(NumInnerOps, DAG.getUNDEF(InnerVT));
    else
      ConcatOps.append(Operand->op_begin(), Operand->op_end());
  }

  return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), N->getValueType(0),
                     ConcatOps);
}

// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
// operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
// most two distinct vectors the same size as the result, attempt to turn this
Expand Down Expand Up @@ -20124,13 +20162,19 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
}

// Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
// FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
if (SDValue V = combineConcatVectorOfScalars(N, DAG))
return V;

// Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT))
if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
// Flatten CONCAT_VECTORS of CONCAT_VECTORS (or undef) operands.
if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
return V;

// Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
return V;
}

if (SDValue V = combineConcatVectorOfCasts(N, DAG))
return V;
Expand Down
29 changes: 25 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10459,8 +10459,29 @@ SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op,
isTypeLegal(Op.getValueType()) &&
"Expected legal scalable vector type!");

if (isTypeLegal(Op.getOperand(0).getValueType()) && Op.getNumOperands() == 2)
return Op;
if (isTypeLegal(Op.getOperand(0).getValueType())) {
unsigned NumOperands = Op->getNumOperands();
assert(NumOperands > 1 && isPowerOf2_32(NumOperands) &&
"Unexpected number of operands in CONCAT_VECTORS");

if (Op.getNumOperands() == 2)
return Op;

// Concat each pair of subvectors and pack into the lower half of the array.
SmallVector<SDValue> ConcatOps(Op->op_begin(), Op->op_end());
while (ConcatOps.size() > 1) {
for (unsigned I = 0, E = ConcatOps.size(); I != E; I += 2) {
SDValue V1 = ConcatOps[I];
SDValue V2 = ConcatOps[I + 1];
EVT SubVT = V1.getValueType();
EVT PairVT = SubVT.getDoubleNumVectorElementsVT(*DAG.getContext());
ConcatOps[I / 2] =
DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Op), PairVT, V1, V2);
}
ConcatOps.resize(ConcatOps.size() / 2);
}
return ConcatOps[0];
}

return SDValue();
}
Expand Down Expand Up @@ -13621,7 +13642,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
// If we see a (concat_vectors (v1x64 A), (v1x64 A)) it's really a vector
// splat. The indexed instructions are going to be expecting a DUPLANE64, so
// canonicalise to that.
if (N0 == N1 && VT.getVectorNumElements() == 2) {
if (N->getNumOperands() == 2 && N0 == N1 && VT.getVectorNumElements() == 2) {
assert(VT.getScalarSizeInBits() == 64);
return DAG.getNode(AArch64ISD::DUPLANE64, dl, VT, WidenVector(N0, DAG),
DAG.getConstant(0, dl, MVT::i64));
Expand All @@ -13636,7 +13657,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
// becomes
// (bitconvert (concat_vectors (v4i16 (bitconvert LHS)), RHS))

if (N1Opc != ISD::BITCAST)
if (N->getNumOperands() != 2 || N1Opc != ISD::BITCAST)
return SDValue();
SDValue RHS = N1->getOperand(0);
MVT RHSTy = RHS.getValueType().getSimpleVT();
Expand Down
96 changes: 55 additions & 41 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8824,54 +8824,68 @@ static SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG,

static SDValue LowerCONCAT_VECTORS_i1(SDValue Op, SelectionDAG &DAG,
const ARMSubtarget *ST) {
SDValue V1 = Op.getOperand(0);
SDValue V2 = Op.getOperand(1);
SDLoc dl(Op);
EVT VT = Op.getValueType();
EVT Op1VT = V1.getValueType();
EVT Op2VT = V2.getValueType();
unsigned NumElts = VT.getVectorNumElements();

assert(Op1VT == Op2VT && "Operand types don't match!");
assert(VT.getScalarSizeInBits() == 1 &&
assert(Op.getValueType().getScalarSizeInBits() == 1 &&
"Unexpected custom CONCAT_VECTORS lowering");
assert(isPowerOf2_32(Op.getNumOperands()) &&
"Unexpected custom CONCAT_VECTORS lowering");
assert(ST->hasMVEIntegerOps() &&
"CONCAT_VECTORS lowering only supported for MVE");

SDValue NewV1 = PromoteMVEPredVector(dl, V1, Op1VT, DAG);
SDValue NewV2 = PromoteMVEPredVector(dl, V2, Op2VT, DAG);

// We now have Op1 + Op2 promoted to vectors of integers, where v8i1 gets
// promoted to v8i16, etc.

MVT ElType = getVectorTyFromPredicateVector(VT).getScalarType().getSimpleVT();

// Extract the vector elements from Op1 and Op2 one by one and truncate them
// to be the right size for the destination. For example, if Op1 is v4i1 then
// the promoted vector is v4i32. The result of concatenation gives a v8i1,
// which when promoted is v8i16. That means each i32 element from Op1 needs
// truncating to i16 and inserting in the result.
EVT ConcatVT = MVT::getVectorVT(ElType, NumElts);
SDValue ConVec = DAG.getNode(ISD::UNDEF, dl, ConcatVT);
auto ExractInto = [&DAG, &dl](SDValue NewV, SDValue ConVec, unsigned &j) {
EVT NewVT = NewV.getValueType();
EVT ConcatVT = ConVec.getValueType();
for (unsigned i = 0, e = NewVT.getVectorNumElements(); i < e; i++, j++) {
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, NewV,
DAG.getIntPtrConstant(i, dl));
ConVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ConcatVT, ConVec, Elt,
DAG.getConstant(j, dl, MVT::i32));
}
return ConVec;
auto ConcatPair = [&](SDValue V1, SDValue V2) {
EVT Op1VT = V1.getValueType();
EVT Op2VT = V2.getValueType();
assert(Op1VT == Op2VT && "Operand types don't match!");
EVT VT = Op1VT.getDoubleNumVectorElementsVT(*DAG.getContext());

SDValue NewV1 = PromoteMVEPredVector(dl, V1, Op1VT, DAG);
SDValue NewV2 = PromoteMVEPredVector(dl, V2, Op2VT, DAG);

// We now have Op1 + Op2 promoted to vectors of integers, where v8i1 gets
// promoted to v8i16, etc.
MVT ElType =
getVectorTyFromPredicateVector(VT).getScalarType().getSimpleVT();
unsigned NumElts = 2 * Op1VT.getVectorNumElements();

// Extract the vector elements from Op1 and Op2 one by one and truncate them
// to be the right size for the destination. For example, if Op1 is v4i1
// then the promoted vector is v4i32. The result of concatenation gives a
// v8i1, which when promoted is v8i16. That means each i32 element from Op1
// needs truncating to i16 and inserting in the result.
EVT ConcatVT = MVT::getVectorVT(ElType, NumElts);
SDValue ConVec = DAG.getNode(ISD::UNDEF, dl, ConcatVT);
auto ExtractInto = [&DAG, &dl](SDValue NewV, SDValue ConVec, unsigned &j) {
EVT NewVT = NewV.getValueType();
EVT ConcatVT = ConVec.getValueType();
for (unsigned i = 0, e = NewVT.getVectorNumElements(); i < e; i++, j++) {
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, NewV,
DAG.getIntPtrConstant(i, dl));
ConVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ConcatVT, ConVec, Elt,
DAG.getConstant(j, dl, MVT::i32));
}
return ConVec;
};
unsigned j = 0;
ConVec = ExtractInto(NewV1, ConVec, j);
ConVec = ExtractInto(NewV2, ConVec, j);

// Now return the result of comparing the subvector with zero,
// which will generate a real predicate, i.e. v4i1, v8i1 or v16i1.
return DAG.getNode(ARMISD::VCMPZ, dl, VT, ConVec,
DAG.getConstant(ARMCC::NE, dl, MVT::i32));
};
unsigned j = 0;
ConVec = ExractInto(NewV1, ConVec, j);
ConVec = ExractInto(NewV2, ConVec, j);

// Now return the result of comparing the subvector with zero,
// which will generate a real predicate, i.e. v4i1, v8i1 or v16i1.
return DAG.getNode(ARMISD::VCMPZ, dl, VT, ConVec,
DAG.getConstant(ARMCC::NE, dl, MVT::i32));
// Concat each pair of subvectors and pack into the lower half of the array.
SmallVector<SDValue> ConcatOps(Op->op_begin(), Op->op_end());
while (ConcatOps.size() > 1) {
for (unsigned I = 0, E = ConcatOps.size(); I != E; I += 2) {
SDValue V1 = ConcatOps[I];
SDValue V2 = ConcatOps[I + 1];
ConcatOps[I / 2] = ConcatPair(V1, V2);
}
ConcatOps.resize(ConcatOps.size() / 2);
}
return ConcatOps[0];
}

static SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG,
Expand Down
Loading

0 comments on commit d6fe8d3

Please sign in to comment.