This has been a TODO for a long time, and it brings many advantages (namely nicer accessors and less fragile code). The existing overloads that accept ArrayRef are now deprecated and will be removed in a follow-up (after a short grace period). Most upstream MLIR usages have been fixed by this commit; the rest will be handled in a follow-up. Differential Revision: https://reviews.llvm.org/D110293
251 lines
9.0 KiB
C++
251 lines
9.0 KiB
C++
//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements patterns to convert Vector dialect to SPIRV dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <numeric>
|
|
|
|
using namespace mlir;
|
|
|
|
/// Gets the first integer value from `attr`, assuming it is an integer array
|
|
/// attribute.
|
|
static uint64_t getFirstIntValue(ArrayAttr attr) {
|
|
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
|
|
}
|
|
|
|
namespace {
|
|
|
|
struct VectorBitcastConvert final
|
|
: public OpConversionPattern<vector::BitCastOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
if (dstType == adaptor.source().getType())
|
|
rewriter.replaceOp(bitcastOp, adaptor.source());
|
|
else
|
|
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
|
|
adaptor.source());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorBroadcastConvert final
|
|
: public OpConversionPattern<vector::BroadcastOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (broadcastOp.source().getType().isa<VectorType>() ||
|
|
!spirv::CompositeType::isValid(broadcastOp.getVectorType()))
|
|
return failure();
|
|
SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
|
|
adaptor.source());
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
|
|
broadcastOp, broadcastOp.getVectorType(), source);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorExtractOpConvert final
|
|
: public OpConversionPattern<vector::ExtractOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// Only support extracting a scalar value now.
|
|
VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
|
|
if (resultVectorType && resultVectorType.getNumElements() > 1)
|
|
return failure();
|
|
|
|
auto dstType = getTypeConverter()->convertType(extractOp.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
|
|
rewriter.replaceOp(extractOp, adaptor.vector());
|
|
return success();
|
|
}
|
|
|
|
int32_t id = getFirstIntValue(extractOp.position());
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
|
|
extractOp, adaptor.vector(), id);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorExtractStridedSliceOpConvert final
|
|
: public OpConversionPattern<vector::ExtractStridedSliceOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto dstType = getTypeConverter()->convertType(extractOp.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
|
|
uint64_t offset = getFirstIntValue(extractOp.offsets());
|
|
uint64_t size = getFirstIntValue(extractOp.sizes());
|
|
uint64_t stride = getFirstIntValue(extractOp.strides());
|
|
if (stride != 1)
|
|
return failure();
|
|
|
|
Value srcVector = adaptor.getOperands().front();
|
|
|
|
// Extract vector<1xT> case.
|
|
if (dstType.isa<spirv::ScalarType>()) {
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
|
|
srcVector, offset);
|
|
return success();
|
|
}
|
|
|
|
SmallVector<int32_t, 2> indices(size);
|
|
std::iota(indices.begin(), indices.end(), offset);
|
|
|
|
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
|
|
extractOp, dstType, srcVector, srcVector,
|
|
rewriter.getI32ArrayAttr(indices));
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Maps vector.fma one-to-one onto spirv::GLSLFmaOp, using the converted
  /// lhs/rhs/acc operands. Vector types rejected by
  /// spirv::CompositeType::isValid are left to other patterns.
  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operandVectorType = fmaOp.getVectorType();
    if (spirv::CompositeType::isValid(operandVectorType)) {
      rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
          fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
      return success();
    }
    return failure();
  }
};
|
|
|
|
struct VectorInsertOpConvert final
|
|
: public OpConversionPattern<vector::InsertOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (insertOp.getSourceType().isa<VectorType>() ||
|
|
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
|
|
return failure();
|
|
int32_t id = getFirstIntValue(insertOp.position());
|
|
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
|
|
insertOp, adaptor.source(), adaptor.dest(), id);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorExtractElementOpConvert final
|
|
: public OpConversionPattern<vector::ExtractElementOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
|
|
return failure();
|
|
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
|
|
extractElementOp, extractElementOp.getType(), adaptor.vector(),
|
|
extractElementOp.position());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorInsertElementOpConvert final
|
|
: public OpConversionPattern<vector::InsertElementOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
|
|
return failure();
|
|
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
|
|
insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
|
|
adaptor.source(), insertElementOp.position());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorInsertStridedSliceOpConvert final
|
|
: public OpConversionPattern<vector::InsertStridedSliceOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Value srcVector = adaptor.getOperands().front();
|
|
Value dstVector = adaptor.getOperands().back();
|
|
|
|
// Insert scalar values not supported yet.
|
|
if (srcVector.getType().isa<spirv::ScalarType>() ||
|
|
dstVector.getType().isa<spirv::ScalarType>())
|
|
return failure();
|
|
|
|
uint64_t stride = getFirstIntValue(insertOp.strides());
|
|
if (stride != 1)
|
|
return failure();
|
|
|
|
uint64_t totalSize =
|
|
dstVector.getType().cast<VectorType>().getNumElements();
|
|
uint64_t insertSize =
|
|
srcVector.getType().cast<VectorType>().getNumElements();
|
|
uint64_t offset = getFirstIntValue(insertOp.offsets());
|
|
|
|
SmallVector<int32_t, 2> indices(totalSize);
|
|
std::iota(indices.begin(), indices.end(), 0);
|
|
std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
|
|
totalSize);
|
|
|
|
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
|
|
insertOp, dstVector.getType(), dstVector, srcVector,
|
|
rewriter.getI32ArrayAttr(indices));
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds every Vector-to-SPIR-V conversion pattern defined in this file to
/// `patterns`. Each pattern is constructed with `typeConverter` and the
/// pattern set's MLIR context.
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
               VectorExtractElementOpConvert, VectorExtractOpConvert,
               VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
               VectorInsertElementOpConvert, VectorInsertOpConvert,
               VectorInsertStridedSliceOpConvert>(typeConverter, context);
}
|