[mlir][vector] Update helpers in VectorEmulateNarrowType.cpp (nfc) (#131527)
Refactors the following pairs of helper hooks: * `dynamicallyInsertSubVector` + `staticallyInsertSubVector` * `dynamicallyExtractSubVector` + `staticallyExtractSubVector` These hooks are very similar, so I have unified the variable names and various conditions to make the actual differences clearer.
This commit is contained in:
committed by
GitHub
parent
3013458a79
commit
9768077de6
@@ -198,85 +198,156 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
|
||||
return *newMask;
|
||||
}
|
||||
|
||||
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
|
||||
/// emitting `vector.extract_strided_slice`.
|
||||
/// Extracts 1-D subvector from a 1-D vector.
|
||||
///
|
||||
/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
|
||||
/// from `src`, starting at `offset`. The result is also a rank-1 vector:
|
||||
///
|
||||
/// vector<numElemsToExtract x !elemType>
|
||||
///
|
||||
/// (`!elType` is the element type of the source vector). As `offset` is a known
|
||||
/// _static_ value, this helper hook emits `vector.extract_strided_slice`.
|
||||
///
|
||||
/// EXAMPLE:
|
||||
/// %res = vector.extract_strided_slice %src
|
||||
/// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
|
||||
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
|
||||
Value source, int64_t frontOffset,
|
||||
int64_t subvecSize) {
|
||||
auto vectorType = cast<VectorType>(source.getType());
|
||||
assert(vectorType.getRank() == 1 && "expected 1-D source types");
|
||||
assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
|
||||
Value src, int64_t offset,
|
||||
int64_t numElemsToExtract) {
|
||||
auto vectorType = cast<VectorType>(src.getType());
|
||||
assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector ");
|
||||
assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
|
||||
"subvector out of bounds");
|
||||
|
||||
// do not need extraction if the subvector size is the same as the source
|
||||
if (vectorType.getNumElements() == subvecSize)
|
||||
return source;
|
||||
// When extracting all available elements, just use the source vector as the
|
||||
// result.
|
||||
if (vectorType.getNumElements() == numElemsToExtract)
|
||||
return src;
|
||||
|
||||
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
|
||||
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
|
||||
auto offsets = rewriter.getI64ArrayAttr({offset});
|
||||
auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
|
||||
auto strides = rewriter.getI64ArrayAttr({1});
|
||||
|
||||
auto resultVectorType =
|
||||
VectorType::get({subvecSize}, vectorType.getElementType());
|
||||
VectorType::get({numElemsToExtract}, vectorType.getElementType());
|
||||
return rewriter
|
||||
.create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
|
||||
.create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
|
||||
offsets, sizes, strides)
|
||||
->getResult(0);
|
||||
}
|
||||
|
||||
/// Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
|
||||
/// at `offset`. it is a wrapper function for emitting
|
||||
/// Inserts 1-D subvector into a 1-D vector.
|
||||
///
|
||||
/// Inserts the input rank-1 source vector into the destination vector starting
|
||||
/// at `offset`. As `offset` is a known _static_ value, this helper hook emits
|
||||
/// `vector.insert_strided_slice`.
|
||||
///
|
||||
/// EXAMPLE:
|
||||
/// %res = vector.insert_strided_slice %src, %dest
|
||||
/// {offsets = [%offset], strides [1]}
|
||||
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
|
||||
Value src, Value dest, int64_t offset) {
|
||||
[[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
|
||||
[[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
|
||||
assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
|
||||
"expected source and dest to be vector type");
|
||||
auto srcVecTy = cast<VectorType>(src.getType());
|
||||
auto destVecTy = cast<VectorType>(dest.getType());
|
||||
assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
|
||||
"expected source and dest to be rank-1 vector types");
|
||||
|
||||
// If overwritting the destination vector, just return the source.
|
||||
if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
|
||||
return src;
|
||||
|
||||
auto offsets = rewriter.getI64ArrayAttr({offset});
|
||||
auto strides = rewriter.getI64ArrayAttr({1});
|
||||
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
|
||||
return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
|
||||
dest, offsets, strides);
|
||||
}
|
||||
|
||||
/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
|
||||
/// and size `numElementsToExtract`, and inserts into the `dest` vector. This
|
||||
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
|
||||
/// use it when `offset` cannot be folded into a constant value.
|
||||
/// Extracts 1-D subvector from a 1-D vector.
|
||||
///
|
||||
/// Given the input rank-1 source vector, extracts `numElemsToExtact` elements
|
||||
/// from `src`, starting at `offset`. The result is also a rank-1 vector:
|
||||
///
|
||||
/// vector<numElemsToExtact x !elType>
|
||||
///
|
||||
/// (`!elType` is the element type of the source vector). As `offset` is assumed
|
||||
/// to be a _dynamic_ SSA value, this helper method generates a sequence of
|
||||
/// `vector.extract` + `vector.insert` pairs.
|
||||
///
|
||||
/// EXAMPLE:
|
||||
/// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
|
||||
/// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
|
||||
/// %c1 = arith.constant 1 : index
|
||||
/// %idx2 = arith.addi %offset, %c1 : index
|
||||
/// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
|
||||
/// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
|
||||
/// (...)
|
||||
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
|
||||
Value source, Value dest,
|
||||
Value src, Value dest,
|
||||
OpFoldResult offset,
|
||||
int64_t numElementsToExtract) {
|
||||
assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
|
||||
for (int i = 0; i < numElementsToExtract; ++i) {
|
||||
int64_t numElemsToExtract) {
|
||||
auto srcVecTy = cast<VectorType>(src.getType());
|
||||
assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector ");
|
||||
// NOTE: We are unable to take the offset into account in the following
|
||||
// assert, hence its still possible that the subvector is out-of-bounds even
|
||||
// if the condition is true.
|
||||
assert(numElemsToExtract <= srcVecTy.getNumElements() &&
|
||||
"subvector out of bounds");
|
||||
|
||||
// When extracting all available elements, just use the source vector as the
|
||||
// result.
|
||||
if (srcVecTy.getNumElements() == numElemsToExtract)
|
||||
return src;
|
||||
|
||||
for (int i = 0; i < numElemsToExtract; ++i) {
|
||||
Value extractLoc =
|
||||
(i == 0) ? offset.dyn_cast<Value>()
|
||||
: rewriter.create<arith::AddIOp>(
|
||||
loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, i));
|
||||
auto extractOp =
|
||||
rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
|
||||
auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
|
||||
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
|
||||
}
|
||||
return dest;
|
||||
}
|
||||
|
||||
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
|
||||
/// Inserts 1-D subvector into a 1-D vector.
|
||||
///
|
||||
/// Inserts the input rank-1 source vector into the destination vector starting
|
||||
/// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
|
||||
/// uses a sequence of `vector.extract` + `vector.insert` pairs.
|
||||
///
|
||||
/// EXAMPLE:
|
||||
/// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
|
||||
/// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
|
||||
/// %c1 = arith.constant 1 : index
|
||||
/// %idx2 = arith.addi %offset, %c1 : index
|
||||
/// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
|
||||
/// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
|
||||
/// (...)
|
||||
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
|
||||
Value source, Value dest,
|
||||
OpFoldResult destOffsetVar,
|
||||
size_t length) {
|
||||
assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
|
||||
assert(length > 0 && "length must be greater than 0");
|
||||
Value destOffsetVal =
|
||||
getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
Value src, Value dest,
|
||||
OpFoldResult offset,
|
||||
int64_t numElemsToInsert) {
|
||||
auto srcVecTy = cast<VectorType>(src.getType());
|
||||
auto destVecTy = cast<VectorType>(dest.getType());
|
||||
assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
|
||||
"expected source and dest to be rank-1 vector types");
|
||||
assert(numElemsToInsert > 0 &&
|
||||
"the number of elements to insert must be greater than 0");
|
||||
// NOTE: We are unable to take the offset into account in the following
|
||||
// assert, hence its still possible that the subvector is out-of-bounds even
|
||||
// if the condition is true.
|
||||
assert(numElemsToInsert <= destVecTy.getNumElements() &&
|
||||
"subvector out of bounds");
|
||||
|
||||
Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
|
||||
for (int64_t i = 0; i < numElemsToInsert; ++i) {
|
||||
auto insertLoc = i == 0
|
||||
? destOffsetVal
|
||||
: rewriter.create<arith::AddIOp>(
|
||||
loc, rewriter.getIndexType(), destOffsetVal,
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, i));
|
||||
auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
|
||||
auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
|
||||
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
|
||||
}
|
||||
return dest;
|
||||
|
||||
Reference in New Issue
Block a user