[IR2Vec] Simplifying creation of Embedder (#143999)

This change simplifies the `Embedder::create()` API by removing its `Expected`-based error handling.

- Changed `Embedder::create()` to return `std::unique_ptr<Embedder>` directly instead of `Expected<std::unique_ptr<Embedder>>`
- Updated documentation and tests to reflect the new API
- Added a death test for an invalid IR2Vec kind in debug mode
- In release mode, `Embedder::create()` simply returns nullptr for invalid kinds instead of creating an error

(Tracking issue - #141817)
This commit is contained in:
S. VenkataKeerthy
2025-06-30 18:24:08 -07:00
committed by GitHub
parent 24c4bba076
commit 9438048816
6 changed files with 32 additions and 55 deletions

View File

@@ -488,14 +488,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
// Assuming F is an llvm::Function&
// For example, using IR2VecKind::Symbolic:
Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
std::unique_ptr<ir2vec::Embedder> Emb =
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
if (auto Err = EmbOrErr.takeError()) {
// Handle error in embedder creation
return;
}
std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
3. **Compute and Access Embeddings**:
Call ``getFunctionVector()`` to get the embedding for the function.

View File

@@ -171,7 +171,7 @@ public:
virtual ~Embedder() = default;
/// Factory method to create an Embedder object.
LLVM_ABI static Expected<std::unique_ptr<Embedder>>
LLVM_ABI static std::unique_ptr<Embedder>
create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
/// Returns a map containing instructions and the corresponding embeddings for

View File

@@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
// We instantiate the IR2Vec embedder each time, as having a unique
// pointer to the embedder as member of the class would make it
// non-copyable. Instantiating the embedder in itself is not costly.
auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
*BB.getParent(), *IR2VecVocab);
if (Error Err = EmbOrErr.takeError()) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
BB.getContext().emitError("Error creating IR2Vec embeddings: " +
EI.message());
});
if (!Embedder) {
BB.getContext().emitError("Error creating IR2Vec embeddings");
return;
}
auto Embedder = std::move(*EmbOrErr);
const auto &BBEmbedding = Embedder->getBBVector(BB);
// Subtract BBEmbedding from Function embedding if the direction is -1,
// and add it if the direction is +1.

View File

@@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
Expected<std::unique_ptr<Embedder>>
Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocab &Vocabulary) {
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
}
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
llvm_unreachable("Unknown IR2Vec kind");
return nullptr;
}
// FIXME: Currently lookups are string based. Use numeric Keys
@@ -384,17 +385,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
auto Vocab = IR2VecVocabResult.getVocabulary();
for (Function &F : M) {
Expected<std::unique_ptr<Embedder>> EmbOrErr =
std::unique_ptr<Embedder> Emb =
Embedder::create(IR2VecKind::Symbolic, F, Vocab);
if (auto Err = EmbOrErr.takeError()) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
});
if (!Emb) {
OS << "Error creating IR2Vec embeddings \n";
continue;
}
std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
OS << "Function vector: ";
Emb->getFunctionVector().print(OS);

View File

@@ -127,10 +127,9 @@ protected:
}
std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
auto EmbResult =
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
EXPECT_TRUE(static_cast<bool>(EmbResult));
return std::move(*EmbResult);
auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
EXPECT_TRUE(static_cast<bool>(Emb));
return std::move(Emb);
}
};

View File

@@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));
auto *Emb = Result->get();
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_NE(Emb, nullptr);
}
@@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
// static_cast an invalid int to IR2VecKind
// static_cast an invalid int to IR2VecKind
#ifndef NDEBUG
#if GTEST_HAS_DEATH_TEST
EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
"Unknown IR2Vec kind");
#endif // GTEST_HAS_DEATH_TEST
#else
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
EXPECT_FALSE(static_cast<bool>(Result));
std::string ErrMsg;
llvm::handleAllErrors(
Result.takeError(),
[&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
#endif // NDEBUG
}
TEST(IR2VecTest, LookupVocab) {
@@ -298,10 +296,6 @@ protected:
Instruction *AddInst = nullptr;
Instruction *RetInst = nullptr;
float OriginalOpcWeight = ::OpcWeight;
float OriginalTypeWeight = ::TypeWeight;
float OriginalArgWeight = ::ArgWeight;
void SetUp() override {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.25, 0.25}},
@@ -325,9 +319,8 @@ protected:
};
TEST_F(IR2VecTestFixture, GetInstVecMap) {
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Result));
auto Emb = std::move(*Result);
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &InstMap = Emb->getInstVecMap();
@@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
}
TEST_F(IR2VecTestFixture, GetBBVecMap) {
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Result));
auto Emb = std::move(*Result);
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBMap = Emb->getBBVecMap();
@@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
}
TEST_F(IR2VecTestFixture, GetBBVector) {
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Result));
auto Emb = std::move(*Result);
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBVec = Emb->getBBVector(*BB);
@@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
}
TEST_F(IR2VecTestFixture, GetFunctionVector) {
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Result));
auto Emb = std::move(*Result);
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &FuncVec = Emb->getFunctionVector();