This PR adds out-of-place arithmetic operators (`+`, `-`, `*`) to the `Embedding` class in IR2Vec, complementing the existing in-place operators (`+=`, `-=`, `*=`). Tests have been added to verify the functionality of these new operators. (Tracking issue - #141817)
440 lines
12 KiB
C++
440 lines
12 KiB
C++
//===- IR2VecTest.cpp - Unit tests for IR2Vec -----------------------------==//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Analysis/IR2Vec.h"
|
|
#include "llvm/IR/Constants.h"
|
|
#include "llvm/IR/Function.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/IR/LLVMContext.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/Support/Error.h"
|
|
#include "llvm/Support/JSON.h"
|
|
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
#include <map>
|
|
#include <vector>
|
|
|
|
using namespace llvm;
|
|
using namespace ir2vec;
|
|
using namespace ::testing;
|
|
|
|
namespace {
|
|
|
|
class TestableEmbedder : public Embedder {
|
|
public:
|
|
TestableEmbedder(const Function &F, const Vocab &V) : Embedder(F, V) {}
|
|
void computeEmbeddings() const override {}
|
|
void computeEmbeddings(const BasicBlock &BB) const override {}
|
|
using Embedder::lookupVocab;
|
|
};
|
|
|
|
TEST(EmbeddingTest, ConstructorsAndAccessors) {
|
|
// Default constructor
|
|
{
|
|
Embedding E;
|
|
EXPECT_TRUE(E.empty());
|
|
EXPECT_EQ(E.size(), 0u);
|
|
}
|
|
|
|
// Constructor with const std::vector<double>&
|
|
{
|
|
std::vector<double> Data = {1.0, 2.0, 3.0};
|
|
Embedding E(Data);
|
|
EXPECT_FALSE(E.empty());
|
|
ASSERT_THAT(E, SizeIs(3u));
|
|
EXPECT_THAT(E.getData(), ElementsAre(1.0, 2.0, 3.0));
|
|
EXPECT_EQ(E[0], 1.0);
|
|
EXPECT_EQ(E[1], 2.0);
|
|
EXPECT_EQ(E[2], 3.0);
|
|
}
|
|
|
|
// Constructor with std::vector<double>&&
|
|
{
|
|
Embedding E(std::vector<double>({4.0, 5.0}));
|
|
ASSERT_THAT(E, SizeIs(2u));
|
|
EXPECT_THAT(E.getData(), ElementsAre(4.0, 5.0));
|
|
}
|
|
|
|
// Constructor with std::initializer_list<double>
|
|
{
|
|
Embedding E({6.0, 7.0, 8.0, 9.0});
|
|
ASSERT_THAT(E, SizeIs(4u));
|
|
EXPECT_THAT(E.getData(), ElementsAre(6.0, 7.0, 8.0, 9.0));
|
|
EXPECT_EQ(E[0], 6.0);
|
|
E[0] = 6.5;
|
|
EXPECT_EQ(E[0], 6.5);
|
|
}
|
|
|
|
// Constructor with size_t
|
|
{
|
|
Embedding E(5);
|
|
ASSERT_THAT(E, SizeIs(5u));
|
|
EXPECT_THAT(E.getData(), ElementsAre(0.0, 0.0, 0.0, 0.0, 0.0));
|
|
}
|
|
|
|
// Constructor with size_t and double
|
|
{
|
|
Embedding E(5, 1.5);
|
|
ASSERT_THAT(E, SizeIs(5u));
|
|
EXPECT_THAT(E.getData(), ElementsAre(1.5, 1.5, 1.5, 1.5, 1.5));
|
|
}
|
|
|
|
// Test iterators
|
|
{
|
|
Embedding E({6.5, 7.0, 8.0, 9.0});
|
|
std::vector<double> VecE;
|
|
for (double Val : E) {
|
|
VecE.push_back(Val);
|
|
}
|
|
EXPECT_THAT(VecE, ElementsAre(6.5, 7.0, 8.0, 9.0));
|
|
|
|
const Embedding CE = E;
|
|
std::vector<double> VecCE;
|
|
for (const double &Val : CE) {
|
|
VecCE.push_back(Val);
|
|
}
|
|
EXPECT_THAT(VecCE, ElementsAre(6.5, 7.0, 8.0, 9.0));
|
|
|
|
EXPECT_EQ(*E.begin(), 6.5);
|
|
EXPECT_EQ(*(E.end() - 1), 9.0);
|
|
EXPECT_EQ(*CE.cbegin(), 6.5);
|
|
EXPECT_EQ(*(CE.cend() - 1), 9.0);
|
|
}
|
|
}
|
|
|
|
TEST(EmbeddingTest, AddVectorsOutOfPlace) {
|
|
Embedding E1 = {1.0, 2.0, 3.0};
|
|
Embedding E2 = {0.5, 1.5, -1.0};
|
|
|
|
Embedding E3 = E1 + E2;
|
|
EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0));
|
|
|
|
// Check that E1 and E2 are unchanged
|
|
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
|
|
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
|
|
}
|
|
|
|
TEST(EmbeddingTest, AddVectors) {
|
|
Embedding E1 = {1.0, 2.0, 3.0};
|
|
Embedding E2 = {0.5, 1.5, -1.0};
|
|
|
|
E1 += E2;
|
|
EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0));
|
|
|
|
// Check that E2 is unchanged
|
|
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
|
|
}
|
|
|
|
TEST(EmbeddingTest, SubtractVectorsOutOfPlace) {
|
|
Embedding E1 = {1.0, 2.0, 3.0};
|
|
Embedding E2 = {0.5, 1.5, -1.0};
|
|
|
|
Embedding E3 = E1 - E2;
|
|
EXPECT_THAT(E3, ElementsAre(0.5, 0.5, 4.0));
|
|
|
|
// Check that E1 and E2 are unchanged
|
|
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
|
|
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
|
|
}
|
|
|
|
TEST(EmbeddingTest, SubtractVectors) {
|
|
Embedding E1 = {1.0, 2.0, 3.0};
|
|
Embedding E2 = {0.5, 1.5, -1.0};
|
|
|
|
E1 -= E2;
|
|
EXPECT_THAT(E1, ElementsAre(0.5, 0.5, 4.0));
|
|
|
|
// Check that E2 is unchanged
|
|
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
|
|
}
|
|
|
|
TEST(EmbeddingTest, ScaleVector) {
|
|
Embedding E1 = {1.0, 2.0, 3.0};
|
|
E1 *= 0.5f;
|
|
EXPECT_THAT(E1, ElementsAre(0.5, 1.0, 1.5));
|
|
}
|
|
|
|
TEST(EmbeddingTest, ScaleVectorOutOfPlace) {
|
|
Embedding E1 = {1.0, 2.0, 3.0};
|
|
Embedding E2 = E1 * 0.5f;
|
|
EXPECT_THAT(E2, ElementsAre(0.5, 1.0, 1.5));
|
|
|
|
// Check that E1 is unchanged
|
|
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
|
|
}
|
|
|
|
TEST(EmbeddingTest, AddScaledVector) {
|
|
Embedding E1 = {1.0, 2.0, 3.0};
|
|
Embedding E2 = {2.0, 0.5, -1.0};
|
|
|
|
E1.scaleAndAdd(E2, 0.5f);
|
|
EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5));
|
|
|
|
// Check that E2 is unchanged
|
|
EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0));
|
|
}
|
|
|
|
TEST(EmbeddingTest, ApproximatelyEqual) {
|
|
Embedding E1 = {1.0, 2.0, 3.0};
|
|
Embedding E2 = {1.0000001, 2.0000001, 3.0000001};
|
|
EXPECT_TRUE(E1.approximatelyEquals(E2)); // Diff = 1e-7
|
|
|
|
Embedding E3 = {1.00002, 2.00002, 3.00002}; // Diff = 2e-5
|
|
EXPECT_FALSE(E1.approximatelyEquals(E3, 1e-6));
|
|
EXPECT_TRUE(E1.approximatelyEquals(E3, 3e-5));
|
|
|
|
Embedding E_clearly_within = {1.0000005, 2.0000005, 3.0000005}; // Diff = 5e-7
|
|
EXPECT_TRUE(E1.approximatelyEquals(E_clearly_within));
|
|
|
|
Embedding E_clearly_outside = {1.00001, 2.00001, 3.00001}; // Diff = 1e-5
|
|
EXPECT_FALSE(E1.approximatelyEquals(E_clearly_outside, 1e-6));
|
|
|
|
Embedding E4 = {1.0, 2.0, 3.5}; // Large diff
|
|
EXPECT_FALSE(E1.approximatelyEquals(E4, 0.01));
|
|
|
|
Embedding E5 = {1.0, 2.0, 3.0};
|
|
EXPECT_TRUE(E1.approximatelyEquals(E5, 0.0));
|
|
EXPECT_TRUE(E1.approximatelyEquals(E5));
|
|
}
|
|
|
|
#if GTEST_HAS_DEATH_TEST
|
|
#ifndef NDEBUG
|
|
TEST(EmbeddingTest, AccessOutOfBounds) {
|
|
Embedding E = {1.0, 2.0, 3.0};
|
|
EXPECT_DEATH(E[3], "Index out of bounds");
|
|
EXPECT_DEATH(E[-1], "Index out of bounds");
|
|
EXPECT_DEATH(E[4] = 4.0, "Index out of bounds");
|
|
}
|
|
|
|
TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) {
|
|
Embedding E1 = {1.0, 2.0};
|
|
Embedding E2 = {1.0};
|
|
EXPECT_DEATH(E1 + E2, "Vectors must have the same dimension");
|
|
}
|
|
|
|
TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
|
|
Embedding E1 = {1.0, 2.0};
|
|
Embedding E2 = {1.0};
|
|
EXPECT_DEATH(E1 += E2, "Vectors must have the same dimension");
|
|
}
|
|
|
|
TEST(EmbeddingTest, MismatchedDimensionsSubtractVectors) {
|
|
Embedding E1 = {1.0, 2.0};
|
|
Embedding E2 = {1.0};
|
|
EXPECT_DEATH(E1 -= E2, "Vectors must have the same dimension");
|
|
}
|
|
|
|
TEST(EmbeddingTest, MismatchedDimensionsAddScaledVector) {
|
|
Embedding E1 = {1.0, 2.0};
|
|
Embedding E2 = {1.0};
|
|
EXPECT_DEATH(E1.scaleAndAdd(E2, 1.0f),
|
|
"Vectors must have the same dimension");
|
|
}
|
|
|
|
TEST(EmbeddingTest, MismatchedDimensionsApproximatelyEqual) {
|
|
Embedding E1 = {1.0, 2.0};
|
|
Embedding E2 = {1.010};
|
|
EXPECT_DEATH(E1.approximatelyEquals(E2),
|
|
"Vectors must have the same dimension");
|
|
}
|
|
#endif // NDEBUG
|
|
#endif // GTEST_HAS_DEATH_TEST
|
|
|
|
TEST(IR2VecTest, CreateSymbolicEmbedder) {
|
|
Vocab V = {{"foo", {1.0, 2.0}}};
|
|
|
|
LLVMContext Ctx;
|
|
Module M("M", Ctx);
|
|
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
|
|
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
|
|
|
|
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
|
EXPECT_NE(Emb, nullptr);
|
|
}
|
|
|
|
TEST(IR2VecTest, CreateInvalidMode) {
|
|
Vocab V = {{"foo", {1.0, 2.0}}};
|
|
|
|
LLVMContext Ctx;
|
|
Module M("M", Ctx);
|
|
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
|
|
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
|
|
|
|
// static_cast an invalid int to IR2VecKind
|
|
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
|
|
EXPECT_FALSE(static_cast<bool>(Result));
|
|
}
|
|
|
|
TEST(IR2VecTest, LookupVocab) {
|
|
Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
|
|
LLVMContext Ctx;
|
|
Module M("M", Ctx);
|
|
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
|
|
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
|
|
|
|
TestableEmbedder E(*F, V);
|
|
auto V_foo = E.lookupVocab("foo");
|
|
EXPECT_EQ(V_foo.size(), 2u);
|
|
EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0));
|
|
|
|
auto V_missing = E.lookupVocab("missing");
|
|
EXPECT_EQ(V_missing.size(), 2u);
|
|
EXPECT_THAT(V_missing, ElementsAre(0.0, 0.0));
|
|
}
|
|
|
|
TEST(IR2VecTest, ZeroDimensionEmbedding) {
|
|
Embedding E1;
|
|
Embedding E2;
|
|
// Should be no-op, but not crash
|
|
E1 += E2;
|
|
E1 -= E2;
|
|
E1.scaleAndAdd(E2, 1.0f);
|
|
EXPECT_TRUE(E1.empty());
|
|
}
|
|
|
|
TEST(IR2VecTest, IR2VecVocabResultValidity) {
|
|
// Default constructed is invalid
|
|
IR2VecVocabResult invalidResult;
|
|
EXPECT_FALSE(invalidResult.isValid());
|
|
#if GTEST_HAS_DEATH_TEST
|
|
#ifndef NDEBUG
|
|
EXPECT_DEATH(invalidResult.getVocabulary(), "IR2Vec Vocabulary is invalid");
|
|
EXPECT_DEATH(invalidResult.getDimension(), "IR2Vec Vocabulary is invalid");
|
|
#endif // NDEBUG
|
|
#endif // GTEST_HAS_DEATH_TEST
|
|
|
|
// Valid vocab
|
|
Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
|
|
IR2VecVocabResult validResult(std::move(V));
|
|
EXPECT_TRUE(validResult.isValid());
|
|
EXPECT_EQ(validResult.getDimension(), 2u);
|
|
}
|
|
|
|
// Fixture for IR2Vec tests requiring IR setup.
|
|
class IR2VecTestFixture : public ::testing::Test {
|
|
protected:
|
|
Vocab V;
|
|
LLVMContext Ctx;
|
|
std::unique_ptr<Module> M;
|
|
Function *F = nullptr;
|
|
BasicBlock *BB = nullptr;
|
|
Instruction *AddInst = nullptr;
|
|
Instruction *RetInst = nullptr;
|
|
|
|
void SetUp() override {
|
|
V = {{"add", {1.0, 2.0}},
|
|
{"integerTy", {0.25, 0.25}},
|
|
{"constant", {0.04, 0.06}},
|
|
{"variable", {0.0, 0.0}},
|
|
{"unknownTy", {0.0, 0.0}}};
|
|
|
|
// Setup IR
|
|
M = std::make_unique<Module>("TestM", Ctx);
|
|
FunctionType *FTy = FunctionType::get(
|
|
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
|
|
false);
|
|
F = Function::Create(FTy, Function::ExternalLinkage, "f", M.get());
|
|
BB = BasicBlock::Create(Ctx, "entry", F);
|
|
Argument *Arg = F->getArg(0);
|
|
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
|
|
|
|
AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
|
|
RetInst = ReturnInst::Create(Ctx, AddInst, BB);
|
|
}
|
|
};
|
|
|
|
TEST_F(IR2VecTestFixture, GetInstVecMap) {
|
|
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
|
ASSERT_TRUE(static_cast<bool>(Emb));
|
|
|
|
const auto &InstMap = Emb->getInstVecMap();
|
|
|
|
EXPECT_EQ(InstMap.size(), 2u);
|
|
EXPECT_TRUE(InstMap.count(AddInst));
|
|
EXPECT_TRUE(InstMap.count(RetInst));
|
|
|
|
EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
|
|
EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
|
|
|
|
// Check values for add: {1.29, 2.31}
|
|
EXPECT_THAT(InstMap.at(AddInst),
|
|
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
|
|
|
|
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
|
|
// vocab
|
|
EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
|
|
}
|
|
|
|
TEST_F(IR2VecTestFixture, GetBBVecMap) {
|
|
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
|
ASSERT_TRUE(static_cast<bool>(Emb));
|
|
|
|
const auto &BBMap = Emb->getBBVecMap();
|
|
|
|
EXPECT_EQ(BBMap.size(), 1u);
|
|
EXPECT_TRUE(BBMap.count(BB));
|
|
EXPECT_EQ(BBMap.at(BB).size(), 2u);
|
|
|
|
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
|
|
// {1.29, 2.31}
|
|
EXPECT_THAT(BBMap.at(BB),
|
|
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
|
|
}
|
|
|
|
TEST_F(IR2VecTestFixture, GetBBVector) {
|
|
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
|
ASSERT_TRUE(static_cast<bool>(Emb));
|
|
|
|
const auto &BBVec = Emb->getBBVector(*BB);
|
|
|
|
EXPECT_EQ(BBVec.size(), 2u);
|
|
EXPECT_THAT(BBVec,
|
|
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
|
|
}
|
|
|
|
TEST_F(IR2VecTestFixture, GetFunctionVector) {
|
|
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
|
|
ASSERT_TRUE(static_cast<bool>(Emb));
|
|
|
|
const auto &FuncVec = Emb->getFunctionVector();
|
|
|
|
EXPECT_EQ(FuncVec.size(), 2u);
|
|
|
|
// Function vector should match BB vector (only one BB): {1.29, 2.31}
|
|
EXPECT_THAT(FuncVec,
|
|
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
|
|
}
|
|
|
|
TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
|
|
Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
|
|
Vocab ExpectedVocab = InitialVocab;
|
|
unsigned ExpectedDim = InitialVocab.begin()->second.size();
|
|
|
|
IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
|
|
|
|
LLVMContext TestCtx;
|
|
Module TestMod("TestModuleForVocabAnalysis", TestCtx);
|
|
ModuleAnalysisManager MAM;
|
|
IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
|
|
|
|
EXPECT_TRUE(Result.isValid());
|
|
ASSERT_FALSE(Result.getVocabulary().empty());
|
|
EXPECT_EQ(Result.getDimension(), ExpectedDim);
|
|
|
|
const auto &ResultVocab = Result.getVocabulary();
|
|
EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
|
|
for (const auto &pair : ExpectedVocab) {
|
|
EXPECT_TRUE(ResultVocab.count(pair.first));
|
|
EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second));
|
|
}
|
|
}
|
|
|
|
} // end anonymous namespace
|