Skip to content

Commit

Permalink
[SLP]Track repeated reduced value as it might be vectorized
Browse files Browse the repository at this point in the history
We need to track changes to the repeated reduced value, since it might be
vectorized in the next attempt at reduction vectorization; tracking it is
required to generate correct code and avoid a compiler crash.

Fixes #111887
  • Loading branch information
alexey-bataev committed Oct 10, 2024
1 parent 1954869 commit 4b5018d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 10 deletions.
22 changes: 12 additions & 10 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1293,8 +1293,7 @@ class BoUpSLP {
using InstrList = SmallVector<Instruction *, 16>;
using ValueSet = SmallPtrSet<Value *, 16>;
using StoreList = SmallVector<StoreInst *, 8>;
using ExtraValueToDebugLocsMap =
MapVector<Value *, SmallVector<Instruction *, 2>>;
using ExtraValueToDebugLocsMap = SmallDenseSet<Value *, 4>;
using OrdersType = SmallVector<unsigned, 4>;

BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti,
Expand Down Expand Up @@ -6322,7 +6321,7 @@ void BoUpSLP::buildExternalUses(
continue;

// Check if the scalar is externally used as an extra arg.
const auto *ExtI = ExternallyUsedValues.find(Scalar);
const auto ExtI = ExternallyUsedValues.find(Scalar);
if (ExtI != ExternallyUsedValues.end()) {
int FoundLane = Entry->findLaneForValue(Scalar);
LLVM_DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane "
Expand Down Expand Up @@ -18820,7 +18819,7 @@ class HorizontalReduction {
// List of the values that were reduced in other trees as part of gather
// nodes and thus requiring extract if fully vectorized in other trees.
SmallPtrSet<Value *, 4> RequiredExtract;
Value *VectorizedTree = nullptr;
WeakTrackingVH VectorizedTree = nullptr;
bool CheckForReusedReductionOps = false;
// Try to vectorize elements based on their type.
SmallVector<InstructionsState> States;
Expand Down Expand Up @@ -18916,6 +18915,7 @@ class HorizontalReduction {
bool SameScaleFactor = false;
bool OptReusedScalars = IsSupportedHorRdxIdentityOp &&
SameValuesCounter.size() != Candidates.size();
BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
if (OptReusedScalars) {
SameScaleFactor =
(RdxKind == RecurKind::Add || RdxKind == RecurKind::FAdd ||
Expand All @@ -18936,6 +18936,7 @@ class HorizontalReduction {
emitScaleForReusedOps(Candidates.front(), Builder, Cnt);
VectorizedTree = GetNewVectorizedTree(VectorizedTree, RedVal);
VectorizedVals.try_emplace(OrigV, Cnt);
ExternallyUsedValues.insert(OrigV);
continue;
}
}
Expand Down Expand Up @@ -19015,17 +19016,18 @@ class HorizontalReduction {
V.reorderBottomToTop(/*IgnoreReorder=*/true);
// Keep extracted other reduction values, if they are used in the
// vectorization trees.
BoUpSLP::ExtraValueToDebugLocsMap LocalExternallyUsedValues;
BoUpSLP::ExtraValueToDebugLocsMap LocalExternallyUsedValues(
ExternallyUsedValues);
// The reduction root is used as the insertion point for new
// instructions, so set it as externally used to prevent it from being
// deleted.
LocalExternallyUsedValues[ReductionRoot];
LocalExternallyUsedValues.insert(ReductionRoot);
for (unsigned Cnt = 0, Sz = ReducedVals.size(); Cnt < Sz; ++Cnt) {
if (Cnt == I || (ShuffledExtracts && Cnt == I - 1))
continue;
for (Value *V : ReducedVals[Cnt])
if (isa<Instruction>(V))
LocalExternallyUsedValues[TrackedVals[V]];
LocalExternallyUsedValues.insert(TrackedVals[V]);
}
if (!IsSupportedHorRdxIdentityOp) {
// Number of uses of the candidates in the vector of values.
Expand Down Expand Up @@ -19054,21 +19056,21 @@ class HorizontalReduction {
// Check if the scalar was vectorized as part of the vectorization
// tree but not the top node.
if (!VLScalars.contains(RdxVal) && V.isVectorized(RdxVal)) {
LocalExternallyUsedValues[RdxVal];
LocalExternallyUsedValues.insert(RdxVal);
continue;
}
Value *OrigV = TrackedToOrig.at(RdxVal);
unsigned NumOps =
VectorizedVals.lookup(OrigV) + At(SameValuesCounter, OrigV);
if (NumOps != ReducedValsToOps.at(OrigV).size())
LocalExternallyUsedValues[RdxVal];
LocalExternallyUsedValues.insert(RdxVal);
}
// Do not need the list of reused scalars in regular mode anymore.
if (!IsSupportedHorRdxIdentityOp)
SameValuesCounter.clear();
for (Value *RdxVal : VL)
if (RequiredExtract.contains(RdxVal))
LocalExternallyUsedValues[RdxVal];
LocalExternallyUsedValues.insert(RdxVal);
V.buildExternalUses(LocalExternallyUsedValues);

V.computeMinimumValueSizes();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S --passes=slp-vectorizer -mtriple=riscv64-unknown-linux-gnu -mattr=+v < %s | FileCheck %s

define void @test() {
; Regression test for the SLP vectorizer crash reported in issue #111887:
; the scalar %0 is a *repeated* reduced value — it feeds the second operand
; of every smax call in the reduction chain below — and must stay tracked
; (as externally used) when the reduction is re-attempted for vectorization.
; CHECK-LABEL: define void @test(
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = call <4 x i16> @llvm.experimental.vp.strided.load.v4i16.p0.i64(ptr align 2 null, i64 6, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, i32 4)
; CHECK-NEXT: [[TMP1:%.*]] = load i16, ptr null, align 2
; CHECK-NEXT: [[TMP2:%.*]] = xor <4 x i16> [[TMP0]], zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = call i16 @llvm.vector.reduce.smax.v4i16(<4 x i16> [[TMP2]])
; CHECK-NEXT: [[TMP4:%.*]] = call i16 @llvm.smax.i16(i16 [[TMP1]], i16 [[TMP3]])
; CHECK-NEXT: [[TMP5:%.*]] = call i16 @llvm.smax.i16(i16 [[TMP4]], i16 0)
; CHECK-NEXT: [[TMP6:%.*]] = tail call i16 @llvm.smax.i16(i16 [[TMP5]], i16 0)
; CHECK-NEXT: ret void
;
entry:
  ; Four i16 loads at stride 6 bytes from null (recognized above as a
  ; strided vector load), each xor'ed with 0, then folded into a single
  ; smax reduction chain. %0 reappears in %2, %6, %10 and %14.
  %0 = load i16, ptr null, align 2
  %1 = xor i16 %0, 0
  %2 = tail call i16 @llvm.smax.i16(i16 %1, i16 %0)
  %3 = tail call i16 @llvm.smax.i16(i16 0, i16 %2)
  %4 = load i16, ptr getelementptr inbounds (i8, ptr null, i64 6), align 2
  %5 = xor i16 %4, 0
  %6 = tail call i16 @llvm.smax.i16(i16 %5, i16 %0)
  %7 = tail call i16 @llvm.smax.i16(i16 %3, i16 %6)
  %8 = load i16, ptr getelementptr (i8, ptr null, i64 12), align 2
  %9 = xor i16 %8, 0
  %10 = tail call i16 @llvm.smax.i16(i16 %9, i16 %0)
  %11 = tail call i16 @llvm.smax.i16(i16 %7, i16 %10)
  %12 = load i16, ptr getelementptr (i8, ptr null, i64 18), align 2
  %13 = xor i16 %12, 0
  %14 = tail call i16 @llvm.smax.i16(i16 %13, i16 %0)
  %15 = tail call i16 @llvm.smax.i16(i16 %11, i16 %14)
  %16 = tail call i16 @llvm.smax.i16(i16 %15, i16 0)
  ret void
}

0 comments on commit 4b5018d

Please sign in to comment.