This is to improve consistency within the SPIR-V dialect and make these ops a bit shorter. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D130280
389 lines
14 KiB
C++
389 lines
14 KiB
C++
//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the Vector dialect to the SPIR-V
// dialect.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
|
|
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
|
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include <numeric>
|
|
|
|
using namespace mlir;
|
|
|
|
/// Gets the first integer value from `attr`, assuming it is an integer array
|
|
/// attribute.
|
|
static uint64_t getFirstIntValue(ArrayAttr attr) {
|
|
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
|
|
}
|
|
|
|
namespace {
|
|
|
|
struct VectorBitcastConvert final
|
|
: public OpConversionPattern<vector::BitCastOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
if (dstType == adaptor.getSource().getType())
|
|
rewriter.replaceOp(bitcastOp, adaptor.getSource());
|
|
else
|
|
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
|
|
adaptor.getSource());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorBroadcastConvert final
|
|
: public OpConversionPattern<vector::BroadcastOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (broadcastOp.getSource().getType().isa<VectorType>() ||
|
|
!spirv::CompositeType::isValid(broadcastOp.getVectorType()))
|
|
return failure();
|
|
SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
|
|
adaptor.getSource());
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
|
|
broadcastOp, broadcastOp.getVectorType(), source);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorExtractOpConvert final
|
|
: public OpConversionPattern<vector::ExtractOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// Only support extracting a scalar value now.
|
|
VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
|
|
if (resultVectorType && resultVectorType.getNumElements() > 1)
|
|
return failure();
|
|
|
|
auto dstType = getTypeConverter()->convertType(extractOp.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
|
|
rewriter.replaceOp(extractOp, adaptor.getVector());
|
|
return success();
|
|
}
|
|
|
|
int32_t id = getFirstIntValue(extractOp.getPosition());
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
|
|
extractOp, adaptor.getVector(), id);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorExtractStridedSliceOpConvert final
|
|
: public OpConversionPattern<vector::ExtractStridedSliceOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto dstType = getTypeConverter()->convertType(extractOp.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
uint64_t offset = getFirstIntValue(extractOp.getOffsets());
|
|
uint64_t size = getFirstIntValue(extractOp.getSizes());
|
|
uint64_t stride = getFirstIntValue(extractOp.getStrides());
|
|
if (stride != 1)
|
|
return failure();
|
|
|
|
Value srcVector = adaptor.getOperands().front();
|
|
|
|
// Extract vector<1xT> case.
|
|
if (dstType.isa<spirv::ScalarType>()) {
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
|
|
srcVector, offset);
|
|
return success();
|
|
}
|
|
|
|
SmallVector<int32_t, 2> indices(size);
|
|
std::iota(indices.begin(), indices.end(), offset);
|
|
|
|
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
|
|
extractOp, dstType, srcVector, srcVector,
|
|
rewriter.getI32ArrayAttr(indices));
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Rewrites vector::FMAOp into spirv::GLFmaOp with the converted result type
/// and the converted lhs/rhs/acc operands.
struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (Type resultType = getTypeConverter()->convertType(fmaOp.getType())) {
      rewriter.replaceOpWithNewOp<spirv::GLFmaOp>(fmaOp, resultType,
                                                  adaptor.getLhs(),
                                                  adaptor.getRhs(),
                                                  adaptor.getAcc());
      return success();
    }
    // The result type could not be converted.
    return failure();
  }
};
|
|
|
|
struct VectorInsertOpConvert final
|
|
: public OpConversionPattern<vector::InsertOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// Special case for inserting scalar values into size-1 vectors.
|
|
if (insertOp.getSourceType().isIntOrFloat() &&
|
|
insertOp.getDestVectorType().getNumElements() == 1) {
|
|
rewriter.replaceOp(insertOp, adaptor.getSource());
|
|
return success();
|
|
}
|
|
|
|
if (insertOp.getSourceType().isa<VectorType>() ||
|
|
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
|
|
return failure();
|
|
int32_t id = getFirstIntValue(insertOp.getPosition());
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
|
|
insertOp, adaptor.getSource(), adaptor.getDest(), id);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorExtractElementOpConvert final
|
|
: public OpConversionPattern<vector::ExtractElementOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
|
|
return failure();
|
|
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
|
|
extractElementOp, extractElementOp.getType(), adaptor.getVector(),
|
|
extractElementOp.getPosition());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorInsertElementOpConvert final
|
|
: public OpConversionPattern<vector::InsertElementOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
|
|
return failure();
|
|
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
|
|
insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
|
|
adaptor.getSource(), insertElementOp.getPosition());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorInsertStridedSliceOpConvert final
|
|
: public OpConversionPattern<vector::InsertStridedSliceOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Value srcVector = adaptor.getOperands().front();
|
|
Value dstVector = adaptor.getOperands().back();
|
|
|
|
uint64_t stride = getFirstIntValue(insertOp.getStrides());
|
|
if (stride != 1)
|
|
return failure();
|
|
uint64_t offset = getFirstIntValue(insertOp.getOffsets());
|
|
|
|
if (srcVector.getType().isa<spirv::ScalarType>()) {
|
|
assert(!dstVector.getType().isa<spirv::ScalarType>());
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
|
|
insertOp, dstVector.getType(), srcVector, dstVector,
|
|
rewriter.getI32ArrayAttr(offset));
|
|
return success();
|
|
}
|
|
|
|
uint64_t totalSize =
|
|
dstVector.getType().cast<VectorType>().getNumElements();
|
|
uint64_t insertSize =
|
|
srcVector.getType().cast<VectorType>().getNumElements();
|
|
|
|
SmallVector<int32_t, 2> indices(totalSize);
|
|
std::iota(indices.begin(), indices.end(), 0);
|
|
std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
|
|
totalSize);
|
|
|
|
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
|
|
insertOp, dstVector.getType(), dstVector, srcVector,
|
|
rewriter.getI32ArrayAttr(indices));
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorReductionPattern final
|
|
: public OpConversionPattern<vector::ReductionOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type resultType = typeConverter->convertType(reduceOp.getType());
|
|
if (!resultType)
|
|
return failure();
|
|
|
|
auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
|
|
if (!srcVectorType || srcVectorType.getRank() != 1)
|
|
return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
|
|
|
|
// Extract all elements.
|
|
int numElements = srcVectorType.getDimSize(0);
|
|
SmallVector<Value, 4> values;
|
|
values.reserve(numElements + (adaptor.getAcc() != nullptr));
|
|
Location loc = reduceOp.getLoc();
|
|
for (int i = 0; i < numElements; ++i) {
|
|
values.push_back(rewriter.create<spirv::CompositeExtractOp>(
|
|
loc, srcVectorType.getElementType(), adaptor.getVector(),
|
|
rewriter.getI32ArrayAttr({i})));
|
|
}
|
|
if (Value acc = adaptor.getAcc())
|
|
values.push_back(acc);
|
|
|
|
// Reduce them.
|
|
Value result = values.front();
|
|
for (Value next : llvm::makeArrayRef(values).drop_front()) {
|
|
switch (reduceOp.getKind()) {
|
|
#define INT_FLOAT_CASE(kind, iop, fop) \
|
|
case vector::CombiningKind::kind: \
|
|
if (resultType.isa<IntegerType>()) { \
|
|
result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
|
|
} else { \
|
|
assert(resultType.isa<FloatType>()); \
|
|
result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
|
|
} \
|
|
break
|
|
|
|
INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
|
|
INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
|
|
|
|
case vector::CombiningKind::MINUI:
|
|
case vector::CombiningKind::MINSI:
|
|
case vector::CombiningKind::MINF:
|
|
case vector::CombiningKind::MAXUI:
|
|
case vector::CombiningKind::MAXSI:
|
|
case vector::CombiningKind::MAXF:
|
|
case vector::CombiningKind::AND:
|
|
case vector::CombiningKind::OR:
|
|
case vector::CombiningKind::XOR:
|
|
return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
|
|
}
|
|
}
|
|
|
|
rewriter.replaceOp(reduceOp, result);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Rewrites vector::SplatOp. If the converted result type is a SPIR-V scalar,
/// the op is replaced by its converted input; otherwise a
/// spirv::CompositeConstructOp repeats the input once per result element.
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
public:
  using OpConversionPattern<vector::SplatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedType = getTypeConverter()->convertType(op.getType());
    if (!convertedType)
      return failure();

    // Scalar result: the input already is the splat.
    if (convertedType.isa<spirv::ScalarType>()) {
      rewriter.replaceOp(op, adaptor.getInput());
      return success();
    }

    auto convertedVectorType = convertedType.cast<VectorType>();
    SmallVector<Value, 4> constituents(convertedVectorType.getNumElements(),
                                       adaptor.getInput());
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, convertedType,
                                                             constituents);
    return success();
  }
};
|
|
|
|
struct VectorShuffleOpConvert final
|
|
: public OpConversionPattern<vector::ShuffleOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto oldResultType = shuffleOp.getVectorType();
|
|
if (!spirv::CompositeType::isValid(oldResultType))
|
|
return failure();
|
|
auto newResultType = getTypeConverter()->convertType(oldResultType);
|
|
|
|
auto oldSourceType = shuffleOp.getV1VectorType();
|
|
if (oldSourceType.getNumElements() > 1) {
|
|
SmallVector<int32_t, 4> components = llvm::to_vector<4>(
|
|
llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
|
|
return attr.cast<IntegerAttr>().getValue().getZExtValue();
|
|
}));
|
|
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
|
|
shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
|
|
rewriter.getI32ArrayAttr(components));
|
|
return success();
|
|
}
|
|
|
|
SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
|
|
SmallVector<Value, 4> newOperands;
|
|
newOperands.reserve(oldResultType.getNumElements());
|
|
for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
|
|
newOperands.push_back(oldOperands[i.getZExtValue()]);
|
|
}
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
|
|
shuffleOp, newResultType, newOperands);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds every Vector-to-SPIR-V conversion pattern defined in this file to
/// `patterns`, each constructed with `typeConverter` and the context owned by
/// `patterns`.
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         RewritePatternSet &patterns) {
  auto *context = patterns.getContext();

  patterns.add<VectorBitcastConvert, VectorBroadcastConvert>(typeConverter,
                                                             context);
  patterns.add<VectorExtractElementOpConvert, VectorExtractOpConvert,
               VectorExtractStridedSliceOpConvert>(typeConverter, context);
  patterns.add<VectorFmaOpConvert>(typeConverter, context);
  patterns.add<VectorInsertElementOpConvert, VectorInsertOpConvert>(
      typeConverter, context);
  patterns.add<VectorReductionPattern, VectorInsertStridedSliceOpConvert,
               VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
                                                           context);
}
|