[mlir][vector] Add vector.from_elements op (#95938)

This commit adds a new operation to the vector dialect:
`vector.from_elements`

The op constructs a new vector from a given list of scalar values. It is
similar to `tensor.from_elements`.
```mlir
%0 = vector.from_elements %a, %b, %c, %a, %a, %a : vector<2x3xf32>
```

Constructing a new vector from elements was tedious before this op
existed: a typical way was to define an `arith.constant ... :
vector<...>`, followed by a chain of `vector.insert`.

Folders/canonicalizations are added that can fold `vector.extract` ops
and convert the `vector.from_elements` op into a `vector.splat` op.

The LLVM lowering generates an `llvm.mlir.undef`, followed by a sequence
of scalar insertions in the form of `llvm.insertelement`. Only 0-D and
1-D vectors are currently supported in the LLVM lowering.
This commit is contained in:
Matthias Springer
2024-06-19 09:58:37 +02:00
committed by GitHub
parent bacbf26b4c
commit c6ff2446a4
7 changed files with 305 additions and 4 deletions

View File

@@ -1836,6 +1836,30 @@ struct VectorDeinterleaveOpLowering
}
};
/// Conversion pattern for a `vector.from_elements`.
struct VectorFromElementsLowering
: public ConvertOpToLLVMPattern<vector::FromElementsOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
// TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
// Such ops should be handled in the same way as vector.insert.
if (vectorType.getRank() > 1)
return rewriter.notifyMatchFailure(fromElementsOp,
"rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
rewriter.replaceOp(fromElementsOp, result);
return success();
}
};
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1861,7 +1885,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering>(converter);
VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}