[IR2Vec] Simplifying creation of Embedder (#143999)
This change simplifies the API by removing the error handling complexity. - Changed `Embedder::create()` to return `std::unique_ptr<Embedder>` directly instead of `Expected<std::unique_ptr<Embedder>>` - Updated documentation and tests to reflect the new API - Added death test for invalid IR2Vec kind in debug mode - In release mode, simply returns nullptr for invalid kinds instead of creating an error (Tracking issue - #141817)
This commit is contained in:
committed by
GitHub
parent
24c4bba076
commit
9438048816
@@ -488,14 +488,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
|
||||
|
||||
// Assuming F is an llvm::Function&
|
||||
// For example, using IR2VecKind::Symbolic:
|
||||
Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
|
||||
std::unique_ptr<ir2vec::Embedder> Emb =
|
||||
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
|
||||
|
||||
if (auto Err = EmbOrErr.takeError()) {
|
||||
// Handle error in embedder creation
|
||||
return;
|
||||
}
|
||||
std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
|
||||
|
||||
3. **Compute and Access Embeddings**:
|
||||
Call ``getFunctionVector()`` to get the embedding for the function.
|
||||
|
||||
@@ -171,7 +171,7 @@ public:
|
||||
virtual ~Embedder() = default;
|
||||
|
||||
/// Factory method to create an Embedder object.
|
||||
LLVM_ABI static Expected<std::unique_ptr<Embedder>>
|
||||
LLVM_ABI static std::unique_ptr<Embedder>
|
||||
create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
|
||||
|
||||
/// Returns a map containing instructions and the corresponding embeddings for
|
||||
|
||||
@@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
|
||||
// We instantiate the IR2Vec embedder each time, as having an unique
|
||||
// pointer to the embedder as member of the class would make it
|
||||
// non-copyable. Instantiating the embedder in itself is not costly.
|
||||
auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
|
||||
auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
|
||||
*BB.getParent(), *IR2VecVocab);
|
||||
if (Error Err = EmbOrErr.takeError()) {
|
||||
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
|
||||
BB.getContext().emitError("Error creating IR2Vec embeddings: " +
|
||||
EI.message());
|
||||
});
|
||||
if (!Embedder) {
|
||||
BB.getContext().emitError("Error creating IR2Vec embeddings");
|
||||
return;
|
||||
}
|
||||
auto Embedder = std::move(*EmbOrErr);
|
||||
const auto &BBEmbedding = Embedder->getBBVector(BB);
|
||||
// Subtract BBEmbedding from Function embedding if the direction is -1,
|
||||
// and add it if the direction is +1.
|
||||
|
||||
@@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
|
||||
Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
|
||||
TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
|
||||
|
||||
Expected<std::unique_ptr<Embedder>>
|
||||
Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
|
||||
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
|
||||
const Vocab &Vocabulary) {
|
||||
switch (Mode) {
|
||||
case IR2VecKind::Symbolic:
|
||||
return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
|
||||
}
|
||||
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
|
||||
llvm_unreachable("Unknown IR2Vec kind");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// FIXME: Currently lookups are string based. Use numeric Keys
|
||||
@@ -384,17 +385,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
|
||||
|
||||
auto Vocab = IR2VecVocabResult.getVocabulary();
|
||||
for (Function &F : M) {
|
||||
Expected<std::unique_ptr<Embedder>> EmbOrErr =
|
||||
std::unique_ptr<Embedder> Emb =
|
||||
Embedder::create(IR2VecKind::Symbolic, F, Vocab);
|
||||
if (auto Err = EmbOrErr.takeError()) {
|
||||
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
|
||||
OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
|
||||
});
|
||||
if (!Emb) {
|
||||
OS << "Error creating IR2Vec embeddings \n";
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
|
||||
|
||||
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
|
||||
OS << "Function vector: ";
|
||||
Emb->getFunctionVector().print(OS);
|
||||
|
||||
@@ -127,10 +127,9 @@ protected:
|
||||
}
|
||||
|
||||
std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
|
||||
auto EmbResult =
|
||||
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
|
||||
EXPECT_TRUE(static_cast<bool>(EmbResult));
|
||||
return std::move(*EmbResult);
|
||||
auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
|
||||
EXPECT_TRUE(static_cast<bool>(Emb));
|
||||
return std::move(Emb);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
|
||||
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
|
||||
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
|
||||
|
||||
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
EXPECT_TRUE(static_cast<bool>(Result));
|
||||
|
||||
auto *Emb = Result->get();
|
||||
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
EXPECT_NE(Emb, nullptr);
|
||||
}
|
||||
|
||||
@@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
|
||||
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
|
||||
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
|
||||
|
||||
// static_cast an invalid int to IR2VecKind
|
||||
// static_cast an invalid int to IR2VecKind
|
||||
#ifndef NDEBUG
|
||||
#if GTEST_HAS_DEATH_TEST
|
||||
EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
|
||||
"Unknown IR2Vec kind");
|
||||
#endif // GTEST_HAS_DEATH_TEST
|
||||
#else
|
||||
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
|
||||
EXPECT_FALSE(static_cast<bool>(Result));
|
||||
|
||||
std::string ErrMsg;
|
||||
llvm::handleAllErrors(
|
||||
Result.takeError(),
|
||||
[&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
|
||||
EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
|
||||
#endif // NDEBUG
|
||||
}
|
||||
|
||||
TEST(IR2VecTest, LookupVocab) {
|
||||
@@ -298,10 +296,6 @@ protected:
|
||||
Instruction *AddInst = nullptr;
|
||||
Instruction *RetInst = nullptr;
|
||||
|
||||
float OriginalOpcWeight = ::OpcWeight;
|
||||
float OriginalTypeWeight = ::TypeWeight;
|
||||
float OriginalArgWeight = ::ArgWeight;
|
||||
|
||||
void SetUp() override {
|
||||
V = {{"add", {1.0, 2.0}},
|
||||
{"integerTy", {0.25, 0.25}},
|
||||
@@ -325,9 +319,8 @@ protected:
|
||||
};
|
||||
|
||||
TEST_F(IR2VecTestFixture, GetInstVecMap) {
|
||||
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
ASSERT_TRUE(static_cast<bool>(Result));
|
||||
auto Emb = std::move(*Result);
|
||||
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
ASSERT_TRUE(static_cast<bool>(Emb));
|
||||
|
||||
const auto &InstMap = Emb->getInstVecMap();
|
||||
|
||||
@@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
|
||||
}
|
||||
|
||||
TEST_F(IR2VecTestFixture, GetBBVecMap) {
|
||||
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
ASSERT_TRUE(static_cast<bool>(Result));
|
||||
auto Emb = std::move(*Result);
|
||||
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
ASSERT_TRUE(static_cast<bool>(Emb));
|
||||
|
||||
const auto &BBMap = Emb->getBBVecMap();
|
||||
|
||||
@@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
|
||||
}
|
||||
|
||||
TEST_F(IR2VecTestFixture, GetBBVector) {
|
||||
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
ASSERT_TRUE(static_cast<bool>(Result));
|
||||
auto Emb = std::move(*Result);
|
||||
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
ASSERT_TRUE(static_cast<bool>(Emb));
|
||||
|
||||
const auto &BBVec = Emb->getBBVector(*BB);
|
||||
|
||||
@@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
|
||||
}
|
||||
|
||||
TEST_F(IR2VecTestFixture, GetFunctionVector) {
|
||||
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
ASSERT_TRUE(static_cast<bool>(Result));
|
||||
auto Emb = std::move(*Result);
|
||||
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
||||
ASSERT_TRUE(static_cast<bool>(Emb));
|
||||
|
||||
const auto &FuncVec = Emb->getFunctionVector();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user