[IR2Vec] Add out-of-place arithmetic operators to Embedding class (#145118)
This PR adds out-of-place arithmetic operators (`+`, `-`, `*`) to the `Embedding` class in IR2Vec, complementing the existing in-place operators (`+=`, `-=`, `*=`). Tests have been added to verify the functionality of these new operators. (Tracking issue - #141817)
This commit is contained in:
committed by
GitHub
parent
efe0deae3f
commit
119292c40b
@@ -107,9 +107,12 @@ public:
|
||||
const std::vector<double> &getData() const { return Data; }
|
||||
|
||||
/// Arithmetic operators
|
||||
Embedding &operator+=(const Embedding &RHS);
|
||||
Embedding &operator-=(const Embedding &RHS);
|
||||
Embedding &operator*=(double Factor);
|
||||
LLVM_ABI Embedding &operator+=(const Embedding &RHS);
|
||||
LLVM_ABI Embedding operator+(const Embedding &RHS) const;
|
||||
LLVM_ABI Embedding &operator-=(const Embedding &RHS);
|
||||
LLVM_ABI Embedding operator-(const Embedding &RHS) const;
|
||||
LLVM_ABI Embedding &operator*=(double Factor);
|
||||
LLVM_ABI Embedding operator*(double Factor) const;
|
||||
|
||||
/// Adds Src Embedding scaled by Factor with the called Embedding.
|
||||
/// Called_Embedding += Src * Factor
|
||||
|
||||
@@ -70,7 +70,6 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
|
||||
// ==----------------------------------------------------------------------===//
|
||||
// Embedding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Embedding &Embedding::operator+=(const Embedding &RHS) {
|
||||
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
|
||||
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
|
||||
@@ -78,6 +77,12 @@ Embedding &Embedding::operator+=(const Embedding &RHS) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
Embedding Embedding::operator+(const Embedding &RHS) const {
|
||||
Embedding Result(*this);
|
||||
Result += RHS;
|
||||
return Result;
|
||||
}
|
||||
|
||||
Embedding &Embedding::operator-=(const Embedding &RHS) {
|
||||
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
|
||||
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
|
||||
@@ -85,12 +90,24 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
Embedding Embedding::operator-(const Embedding &RHS) const {
|
||||
Embedding Result(*this);
|
||||
Result -= RHS;
|
||||
return Result;
|
||||
}
|
||||
|
||||
Embedding &Embedding::operator*=(double Factor) {
|
||||
std::transform(this->begin(), this->end(), this->begin(),
|
||||
[Factor](double Elem) { return Elem * Factor; });
|
||||
return *this;
|
||||
}
|
||||
|
||||
Embedding Embedding::operator*(double Factor) const {
|
||||
Embedding Result(*this);
|
||||
Result *= Factor;
|
||||
return Result;
|
||||
}
|
||||
|
||||
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
|
||||
assert(this->size() == Src.size() && "Vectors must have the same dimension");
|
||||
for (size_t Itr = 0; Itr < this->size(); ++Itr)
|
||||
|
||||
@@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(EmbeddingTest, AddVectorsOutOfPlace) {
|
||||
Embedding E1 = {1.0, 2.0, 3.0};
|
||||
Embedding E2 = {0.5, 1.5, -1.0};
|
||||
|
||||
Embedding E3 = E1 + E2;
|
||||
EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0));
|
||||
|
||||
// Check that E1 and E2 are unchanged
|
||||
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
|
||||
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
|
||||
}
|
||||
|
||||
TEST(EmbeddingTest, AddVectors) {
|
||||
Embedding E1 = {1.0, 2.0, 3.0};
|
||||
Embedding E2 = {0.5, 1.5, -1.0};
|
||||
@@ -120,6 +132,18 @@ TEST(EmbeddingTest, AddVectors) {
|
||||
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
|
||||
}
|
||||
|
||||
TEST(EmbeddingTest, SubtractVectorsOutOfPlace) {
|
||||
Embedding E1 = {1.0, 2.0, 3.0};
|
||||
Embedding E2 = {0.5, 1.5, -1.0};
|
||||
|
||||
Embedding E3 = E1 - E2;
|
||||
EXPECT_THAT(E3, ElementsAre(0.5, 0.5, 4.0));
|
||||
|
||||
// Check that E1 and E2 are unchanged
|
||||
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
|
||||
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
|
||||
}
|
||||
|
||||
TEST(EmbeddingTest, SubtractVectors) {
|
||||
Embedding E1 = {1.0, 2.0, 3.0};
|
||||
Embedding E2 = {0.5, 1.5, -1.0};
|
||||
@@ -137,6 +161,15 @@ TEST(EmbeddingTest, ScaleVector) {
|
||||
EXPECT_THAT(E1, ElementsAre(0.5, 1.0, 1.5));
|
||||
}
|
||||
|
||||
TEST(EmbeddingTest, ScaleVectorOutOfPlace) {
|
||||
Embedding E1 = {1.0, 2.0, 3.0};
|
||||
Embedding E2 = E1 * 0.5f;
|
||||
EXPECT_THAT(E2, ElementsAre(0.5, 1.0, 1.5));
|
||||
|
||||
// Check that E1 is unchanged
|
||||
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
|
||||
}
|
||||
|
||||
TEST(EmbeddingTest, AddScaledVector) {
|
||||
Embedding E1 = {1.0, 2.0, 3.0};
|
||||
Embedding E2 = {2.0, 0.5, -1.0};
|
||||
@@ -180,6 +213,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) {
|
||||
EXPECT_DEATH(E[4] = 4.0, "Index out of bounds");
|
||||
}
|
||||
|
||||
TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) {
|
||||
Embedding E1 = {1.0, 2.0};
|
||||
Embedding E2 = {1.0};
|
||||
EXPECT_DEATH(E1 + E2, "Vectors must have the same dimension");
|
||||
}
|
||||
|
||||
TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
|
||||
Embedding E1 = {1.0, 2.0};
|
||||
Embedding E2 = {1.0};
|
||||
|
||||
Reference in New Issue
Block a user