[MLIR][Presburger] Implement PresburgerSpace::mergeAndAlignSymbols (#76397)
This commit is contained in:
committed by
GitHub
parent
945c2e6d92
commit
ff80414620
@@ -290,6 +290,11 @@ public:
|
||||
/// the symbols in two spaces are aligned.
|
||||
bool isAligned(const PresburgerSpace &other, VarKind kind) const;
|
||||
|
||||
/// Merge and align symbol variables of `this` and `other` with respect to
|
||||
/// identifiers. After this operation the symbol variables of both spaces have
|
||||
/// the same identifiers in the same order.
|
||||
void mergeAndAlignSymbols(PresburgerSpace &other);
|
||||
|
||||
void print(llvm::raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
|
||||
@@ -294,6 +294,40 @@ void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
|
||||
// `identifiers` remains same.
|
||||
}
|
||||
|
||||
void PresburgerSpace::mergeAndAlignSymbols(PresburgerSpace &other) {
|
||||
assert(usingIds && other.usingIds &&
|
||||
"Both spaces need to have identifers to merge & align");
|
||||
|
||||
// First merge & align identifiers into `other` from `this`.
|
||||
unsigned kindBeginOffset = other.getVarKindOffset(VarKind::Symbol);
|
||||
unsigned i = 0;
|
||||
for (const Identifier *identifier =
|
||||
identifiers.begin() + getVarKindOffset(VarKind::Symbol);
|
||||
identifier != identifiers.begin() + getVarKindEnd(VarKind::Symbol);
|
||||
identifier++) {
|
||||
// If the identifier exists in `other`, then align it; otherwise insert it
|
||||
// assuming it is a new identifier. Search in `other` starting at position
|
||||
// `i` since the left of `i` is aligned.
|
||||
auto *findEnd =
|
||||
other.identifiers.begin() + other.getVarKindEnd(VarKind::Symbol);
|
||||
auto *itr = std::find(other.identifiers.begin() + kindBeginOffset + i,
|
||||
findEnd, *identifier);
|
||||
if (itr != findEnd) {
|
||||
std::iter_swap(other.identifiers.begin() + kindBeginOffset + i, itr);
|
||||
} else {
|
||||
other.insertVar(VarKind::Symbol, i);
|
||||
other.getId(VarKind::Symbol, i) = *identifier;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
// Finally add identifiers that are in `other`, but not in `this` to `this`.
|
||||
for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; i++) {
|
||||
insertVar(VarKind::Symbol, i);
|
||||
getId(VarKind::Symbol, i) = other.getId(VarKind::Symbol, i);
|
||||
}
|
||||
}
|
||||
|
||||
void PresburgerSpace::print(llvm::raw_ostream &os) const {
|
||||
os << "Domain: " << getNumDomainVars() << ", "
|
||||
<< "Range: " << getNumRangeVars() << ", "
|
||||
|
||||
@@ -193,3 +193,73 @@ TEST(PresburgerSpaceTest, convertVarKind2) {
|
||||
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
|
||||
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[3]));
|
||||
}
|
||||
|
||||
TEST(PresburgerSpaceTest, mergeAndAlignSymbols) {
|
||||
PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
|
||||
space.resetIds();
|
||||
|
||||
PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
|
||||
otherSpace.resetIds();
|
||||
|
||||
// Attach identifiers.
|
||||
int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
|
||||
int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
|
||||
|
||||
space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
|
||||
space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
|
||||
// Note the common identifier.
|
||||
space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
|
||||
space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
|
||||
space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
|
||||
space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
|
||||
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
|
||||
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
|
||||
|
||||
otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
|
||||
otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
|
||||
otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
|
||||
otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
|
||||
otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
|
||||
// Note the common identifier.
|
||||
otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
|
||||
otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
|
||||
otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
|
||||
|
||||
space.mergeAndAlignSymbols(otherSpace);
|
||||
|
||||
// Check if merge & align is successful.
|
||||
// Check symbol var identifiers.
|
||||
EXPECT_EQ(4u, space.getNumSymbolVars());
|
||||
EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
|
||||
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
|
||||
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
|
||||
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
|
||||
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
|
||||
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
|
||||
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
|
||||
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
|
||||
Identifier(&otherIdentifiers[5]));
|
||||
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
|
||||
Identifier(&otherIdentifiers[7]));
|
||||
// Check that domain and range var identifiers are not affected.
|
||||
EXPECT_EQ(3u, space.getNumDomainVars());
|
||||
EXPECT_EQ(3u, space.getNumRangeVars());
|
||||
EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
|
||||
EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
|
||||
EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
|
||||
EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
|
||||
EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
|
||||
EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
|
||||
EXPECT_EQ(3u, otherSpace.getNumDomainVars());
|
||||
EXPECT_EQ(2u, otherSpace.getNumRangeVars());
|
||||
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
|
||||
Identifier(&otherIdentifiers[0]));
|
||||
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
|
||||
Identifier(&otherIdentifiers[1]));
|
||||
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 2),
|
||||
Identifier(&otherIdentifiers[2]));
|
||||
EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
|
||||
Identifier(&otherIdentifiers[3]));
|
||||
EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
|
||||
Identifier(&otherIdentifiers[4]));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user