[mlir][vector] Add vector.from_elements op (#95938)
This commit adds a new operation to the vector dialect: `vector.from_elements` The op constructs a new vector from a given list of scalar values. It is similar to `tensor.from_elements`. ```mlir %0 = vector.from_elements %a, %b, %c, %a, %a, %a : vector<2x3xf32> ``` Constructing a new vector from elements was tedious before this op existed: a typical way was to define an `arith.constant ... : vector<...>`, followed by a chain of `vector.insert`. Folders/canonicalizations are added that can fold `vector.extract` ops and convert the `vector.from_elements` op into a `vector.splat` op. The LLVM lowering generates an `llvm.mlir.undef`, followed by a sequence of scalar insertions in the form of `llvm.insertelement`. Only 0-D and 1-D vectors are currently supported in the LLVM lowering.
This commit is contained in:
committed by
GitHub
parent
bacbf26b4c
commit
c6ff2446a4
@@ -1836,6 +1836,30 @@ struct VectorDeinterleaveOpLowering
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a `vector.from_elements`.
|
||||
struct VectorFromElementsLowering
|
||||
: public ConvertOpToLLVMPattern<vector::FromElementsOp> {
|
||||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = fromElementsOp.getLoc();
|
||||
VectorType vectorType = fromElementsOp.getType();
|
||||
// TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
|
||||
// Such ops should be handled in the same way as vector.insert.
|
||||
if (vectorType.getRank() > 1)
|
||||
return rewriter.notifyMatchFailure(fromElementsOp,
|
||||
"rank > 1 vectors are not supported");
|
||||
Type llvmType = typeConverter->convertType(vectorType);
|
||||
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
|
||||
for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
|
||||
result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
|
||||
rewriter.replaceOp(fromElementsOp, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
@@ -1861,7 +1885,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorSplatOpLowering, VectorSplatNdOpLowering,
|
||||
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
|
||||
MaskedReductionOpConversion, VectorInterleaveOpLowering,
|
||||
VectorDeinterleaveOpLowering>(converter);
|
||||
VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
|
||||
converter);
|
||||
// Transfer ops with rank > 1 are handled by VectorToSCF.
|
||||
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user