[IR][ShuffleVector] Introduce isReplicationMask() matcher
Avid readers of this saga may recall from previous installments
that a replication mask replicates (lol) each of the `VF` elements
in a vector `ReplicationFactor` times. For example, the mask for
`ReplicationFactor=3` and `VF=4` is: `<0,0,0,1,1,1,2,2,2,3,3,3>`.
More importantly, the replication mask is used by LoopVectorizer
when using masked interleaved memory operations.
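
To make the mask layout concrete, here is a minimal standalone sketch that builds such a mask with plain `std::vector`. The helper name `makeReplicationMask` is hypothetical and exists only for illustration; it is not part of this patch (the unit tests in the patch call `createReplicatedMask()` instead).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, for illustration only (not part of this patch):
// build the replication mask for the given parameters,
// e.g. ReplicationFactor=3, VF=4 gives <0,0,0,1,1,1,2,2,2,3,3,3>.
std::vector<int> makeReplicationMask(int ReplicationFactor, int VF) {
  assert(ReplicationFactor > 0 && VF > 0 && "Parameters must be positive.");
  std::vector<int> Mask;
  Mask.reserve(static_cast<std::size_t>(ReplicationFactor) * VF);
  for (int Elt = 0; Elt < VF; ++Elt)
    for (int Copy = 0; Copy < ReplicationFactor; ++Copy)
      Mask.push_back(Elt); // each source element appears RF times in a row
  return Mask;
}
```

Note the two degenerate ends of the range: `ReplicationFactor=1` yields an identity mask, and `VF=1` yields a broadcast of element 0.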

As discussed in previous installments, while it is used by LV,
and we **seem** to support masked interleaved memory operations on X86,
its support in the cost model leaves a lot to be desired:
until basically yesterday, even for AVX512 we had no cost model for it.

As has been witnessed in the recent
AVX2 `X86TTIImpl::getInterleavedMemoryOpCost()`
costmodel patches, while it is hard enough to query the cost
of a particular assembly sequence [from llvm-mca],
afterwards the check lines of the LV costmodel tests must be updated manually.
This is, at the very least, boring.

Okay, now we have decent costmodel coverage for interleaving shuffles,
but now basically the same mind-killing sequence has to be performed
for replication masks. I think we can improve at least the second half
of the problem by teaching
`TargetTransformInfoImplCRTPBase::getUserCost()` to recognize
`Instruction::ShuffleVector`s that are replication masks, and
adding exhaustive test coverage
using `-cost-model -analyze` + `utils/update_analyze_test_checks.py`.

This way we can have good exhaustive coverage for cost model,
and only basic coverage for the LV costmodel.

This patch adds a precise undef-aware `isReplicationMask()`,
with exhaustive test coverage.
* `InstructionsTest.ShuffleMaskIsReplicationMask` shows that
  it correctly detects all the known masks.
* `InstructionsTest.ShuffleMaskIsReplicationMask_undef`
  shows that replacing some mask elements in a known replication mask
  with undef still allows us to recognize it as a replication mask.
  Note that with enough undef elts, we may detect a different tuple.
* `InstructionsTest.ShuffleMaskIsReplicationMask_Exhaustive_Correctness`
  shows that if we detected the replication mask with given params,
  then if we actually generate a true replication mask with said params,
  it matches element-wise ignoring undef mask elements.
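
The undef-tolerant matching described in the list above can be sketched without any LLVM dependencies. This is an illustrative simplification, not the code in this patch: `Undef`, `matchesWithParams` and `matchReplicationMask` are hypothetical names, `std::vector<int>` stands in for `ArrayRef<int>`, and the sketch always enumerates candidate factors rather than special-casing undef-free masks.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int Undef = -1; // stands in for LLVM's UndefMaskElem

// Hypothetical helper: does Mask agree with the replication mask for exactly
// (RF, VF), treating undef elements as wildcards?
bool matchesWithParams(const std::vector<int> &Mask, int RF, int VF) {
  assert(Mask.size() == static_cast<std::size_t>(RF) * VF &&
         "Unexpected mask size.");
  for (std::size_t I = 0; I < Mask.size(); ++I) {
    int Expected = static_cast<int>(I) / RF; // element index of this run
    if (Mask[I] != Undef && Mask[I] != Expected)
      return false;
  }
  return true;
}

// Hypothetical matcher: try every divisor of the mask size as a replication
// factor, largest first, and report the first (RF, VF) tuple that fits.
bool matchReplicationMask(const std::vector<int> &Mask, int &RF, int &VF) {
  int Size = static_cast<int>(Mask.size());
  for (int Candidate = Size; Candidate >= 1; --Candidate) {
    if (Size % Candidate != 0)
      continue; // mask size must equal RF * VF
    if (!matchesWithParams(Mask, Candidate, Size / Candidate))
      continue;
    RF = Candidate;
    VF = Size / Candidate;
    return true;
  }
  return false;
}
```

Under this sketch, `<0,0,-1,-1>` (derived from the RF=2, VF=2 mask `<0,0,1,1>`) is matched as RF=4, VF=1 because the larger factor is tried first. That is the "different tuple" effect mentioned in the second bullet, and it is why the third test only compares non-undef elements against the regenerated mask.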

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D113214
LebedevRI committed Nov 5, 2021
1 parent 7a98761 commit 01d8759
Showing 3 changed files with 183 additions and 1 deletion.
28 changes: 28 additions & 0 deletions llvm/include/llvm/IR/Instructions.h
@@ -2354,6 +2354,34 @@ class ShuffleVectorInst : public Instruction {
    return isInsertSubvectorMask(ShuffleMask, NumSrcElts, NumSubElts, Index);
  }

  /// Return true if this shuffle mask replicates each of the \p VF elements
  /// in a vector \p ReplicationFactor times.
  /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
  ///   <0,0,0,1,1,1,2,2,2,3,3,3>
  static bool isReplicationMask(ArrayRef<int> Mask, int &ReplicationFactor,
                                int &VF);
  static bool isReplicationMask(const Constant *Mask, int &ReplicationFactor,
                                int &VF) {
    assert(Mask->getType()->isVectorTy() && "Shuffle needs vector constant.");
    // Not possible to express a shuffle mask for a scalable vector for this
    // case.
    if (isa<ScalableVectorType>(Mask->getType()))
      return false;
    SmallVector<int, 16> MaskAsInts;
    getShuffleMask(Mask, MaskAsInts);
    return isReplicationMask(MaskAsInts, ReplicationFactor, VF);
  }

  /// Return true if this shuffle mask is a replication mask.
  bool isReplicationMask(int &ReplicationFactor, int &VF) const {
    // Not possible to express a shuffle mask for a scalable vector for this
    // case.
    if (isa<ScalableVectorType>(getType()))
      return false;

    return isReplicationMask(ShuffleMask, ReplicationFactor, VF);
  }

  /// Change values in a shuffle permute mask assuming the two vector operands
  /// of length InVecNumElts have swapped position.
  static void commuteShuffleMask(MutableArrayRef<int> Mask,
66 changes: 66 additions & 0 deletions llvm/lib/IR/Instructions.cpp
@@ -2436,6 +2436,72 @@ bool ShuffleVectorInst::isConcat() const {
  return isIdentityMaskImpl(getShuffleMask(), NumMaskElts);
}

static bool isReplicationMaskWithParams(ArrayRef<int> Mask,
                                        int ReplicationFactor, int VF) {
  assert(Mask.size() == (unsigned)ReplicationFactor * VF &&
         "Unexpected mask size.");

  for (int CurrElt : seq(0, VF)) {
    ArrayRef<int> CurrSubMask = Mask.take_front(ReplicationFactor);
    assert(CurrSubMask.size() == (unsigned)ReplicationFactor &&
           "Run out of mask?");
    Mask = Mask.drop_front(ReplicationFactor);
    if (!all_of(CurrSubMask, [CurrElt](int MaskElt) {
          return MaskElt == UndefMaskElem || MaskElt == CurrElt;
        }))
      return false;
  }
  assert(Mask.empty() && "Did not consume the whole mask?");

  return true;
}

bool ShuffleVectorInst::isReplicationMask(ArrayRef<int> Mask,
                                          int &ReplicationFactor, int &VF) {
  // undef-less case is trivial.
  if (none_of(Mask, [](int MaskElt) { return MaskElt == UndefMaskElem; })) {
    ReplicationFactor =
        Mask.take_while([](int MaskElt) { return MaskElt == 0; }).size();
    if (ReplicationFactor == 0 || Mask.size() % ReplicationFactor != 0)
      return false;
    VF = Mask.size() / ReplicationFactor;
    return isReplicationMaskWithParams(Mask, ReplicationFactor, VF);
  }

  // However, if the mask contains undef's, we have to enumerate possible tuples
  // and pick one. There are bounds on replication factor: [1, mask size]
  // (where RF=1 is an identity shuffle, RF=mask size is a broadcast shuffle)
  // Additionally, mask size is a replication factor multiplied by vector size,
  // which further significantly reduces the search space.

  // Before doing that, let's perform basic sanity check first.
  int Largest = -1;
  for (int MaskElt : Mask) {
    if (MaskElt == UndefMaskElem)
      continue;
    // Elements must be in non-decreasing order.
    if (MaskElt < Largest)
      return false;
    Largest = std::max(Largest, MaskElt);
  }

  // Prefer larger replication factor if all else equal.
  for (int PossibleReplicationFactor :
       reverse(seq_inclusive<unsigned>(1, Mask.size()))) {
    if (Mask.size() % PossibleReplicationFactor != 0)
      continue;
    int PossibleVF = Mask.size() / PossibleReplicationFactor;
    if (!isReplicationMaskWithParams(Mask, PossibleReplicationFactor,
                                     PossibleVF))
      continue;
    ReplicationFactor = PossibleReplicationFactor;
    VF = PossibleVF;
    return true;
  }

  return false;
}

//===----------------------------------------------------------------------===//
// InsertValueInst Class
//===----------------------------------------------------------------------===//
90 changes: 89 additions & 1 deletion llvm/unittests/IR/InstructionsTest.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/IR/Instructions.h"
#include "llvm/ADT/CombinationGenerator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
@@ -1115,6 +1117,92 @@ TEST(InstructionsTest, ShuffleMaskQueries) {
  delete Id15;
}

TEST(InstructionsTest, ShuffleMaskIsReplicationMask) {
  for (int ReplicationFactor : seq_inclusive(1, 8)) {
    for (int VF : seq_inclusive(1, 8)) {
      const auto ReplicatedMask = createReplicatedMask(ReplicationFactor, VF);
      int GuessedReplicationFactor = -1, GuessedVF = -1;
      EXPECT_TRUE(ShuffleVectorInst::isReplicationMask(
          ReplicatedMask, GuessedReplicationFactor, GuessedVF));
      EXPECT_EQ(GuessedReplicationFactor, ReplicationFactor);
      EXPECT_EQ(GuessedVF, VF);
    }
  }
}

TEST(InstructionsTest, ShuffleMaskIsReplicationMask_undef) {
  for (int ReplicationFactor : seq_inclusive(1, 6)) {
    for (int VF : seq_inclusive(1, 4)) {
      const auto ReplicatedMask = createReplicatedMask(ReplicationFactor, VF);
      int GuessedReplicationFactor = -1, GuessedVF = -1;

      // If we change some mask elements to undef, we should still match.

      SmallVector<SmallVector<bool>> ElementChoices(ReplicatedMask.size(),
                                                    {false, true});

      CombinationGenerator<bool, decltype(ElementChoices)::value_type,
                           /*variable_smallsize=*/4>
          G(ElementChoices);

      G.generate([&](ArrayRef<bool> UndefOverrides) -> bool {
        SmallVector<int> AdjustedMask;
        AdjustedMask.reserve(ReplicatedMask.size());
        for (auto I : zip(ReplicatedMask, UndefOverrides))
          AdjustedMask.emplace_back(std::get<1>(I) ? -1 : std::get<0>(I));
        assert(AdjustedMask.size() == ReplicatedMask.size() &&
               "Size misprediction");

        EXPECT_TRUE(ShuffleVectorInst::isReplicationMask(
            AdjustedMask, GuessedReplicationFactor, GuessedVF));
        // Do not check GuessedReplicationFactor and GuessedVF,
        // with enough undef's we may deduce a different tuple.

        return /*Abort=*/false;
      });
    }
  }
}

TEST(InstructionsTest, ShuffleMaskIsReplicationMask_Exhaustive_Correctness) {
  for (int ShufMaskNumElts : seq_inclusive(1, 8)) {
    SmallVector<int> PossibleShufMaskElts;
    PossibleShufMaskElts.reserve(ShufMaskNumElts + 2);
    for (int PossibleShufMaskElt : seq_inclusive(-1, ShufMaskNumElts))
      PossibleShufMaskElts.emplace_back(PossibleShufMaskElt);
    assert(PossibleShufMaskElts.size() == ShufMaskNumElts + 2U &&
           "Size misprediction");

    SmallVector<SmallVector<int>> ElementChoices(ShufMaskNumElts,
                                                 PossibleShufMaskElts);

    CombinationGenerator<int, decltype(ElementChoices)::value_type,
                         /*variable_smallsize=*/4>
        G(ElementChoices);

    G.generate([&](ArrayRef<int> Mask) -> bool {
      int GuessedReplicationFactor = -1, GuessedVF = -1;
      bool Match = ShuffleVectorInst::isReplicationMask(
          Mask, GuessedReplicationFactor, GuessedVF);
      if (!Match)
        return /*Abort=*/false;

      const auto ActualMask =
          createReplicatedMask(GuessedReplicationFactor, GuessedVF);
      EXPECT_EQ(Mask.size(), ActualMask.size());
      for (auto I : zip(Mask, ActualMask)) {
        int Elt = std::get<0>(I);
        // Compare against the regenerated mask (second zip component);
        // reading component 0 twice would make the check vacuous.
        int ActualElt = std::get<1>(I);

        if (Elt != -1)
          EXPECT_EQ(Elt, ActualElt);
      }

      return /*Abort=*/false;
    });
  }
}

TEST(InstructionsTest, GetSplat) {
  // Create the elements for various constant vectors.
  LLVMContext Ctx;