Files
clang-p2996/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
River Riddle b54c724be0 [mlir:OpConversionPattern] Add overloads for taking an Adaptor instead of ArrayRef
This has been a TODO for a long time, and it brings about many advantages (namely nice accessors, and less fragile code). The existing overloads that accept ArrayRef are now treated as deprecated and will be removed in a followup (after a small grace period). Most of the upstream MLIR usages have been fixed by this commit, the rest will be handled in a followup.

Differential Revision: https://reviews.llvm.org/D110293
2021-09-24 17:51:41 +00:00

251 lines
9.0 KiB
C++

//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Vector dialect to SPIRV dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "../PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include <numeric>
using namespace mlir;
/// Returns the leading entry of `attr`, interpreted as an IntegerAttr and
/// zero-extended to 64 bits. `attr` is expected to be a non-empty array whose
/// elements are integer attributes (offsets, sizes, strides, positions, ...).
static uint64_t getFirstIntValue(ArrayAttr attr) {
  IntegerAttr leading = attr.getValue().front().cast<IntegerAttr>();
  return leading.getValue().getZExtValue();
}
namespace {
/// Lowers vector.bitcast to spv.Bitcast. If the converted result type is
/// identical to the type of the (converted) source operand, the source value is
/// forwarded directly instead of emitting a no-op cast.
struct VectorBitcastConvert final
    : public OpConversionPattern<vector::BitCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = getTypeConverter()->convertType(bitcastOp.getType());
    if (!resultType)
      return failure();

    Value input = adaptor.source();
    if (resultType != input.getType()) {
      rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, resultType,
                                                    input);
      return success();
    }

    // Same type on both sides: the cast is an identity.
    rewriter.replaceOp(bitcastOp, input);
    return success();
  }
};
/// Lowers vector.broadcast of a non-vector source to spv.CompositeConstruct,
/// repeating the (converted) source value once per result element. Broadcasts
/// whose source is itself a vector, or whose result type is not a valid SPIR-V
/// composite, are rejected.
struct VectorBroadcastConvert final
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool sourceIsVector = broadcastOp.source().getType().isa<VectorType>();
    if (sourceIsVector ||
        !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
      return failure();

    auto resultType = broadcastOp.getVectorType();
    SmallVector<Value, 4> elements;
    elements.assign(resultType.getNumElements(), adaptor.source());
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
        broadcastOp, resultType, elements);
    return success();
  }
};
/// Lowers vector.extract with a single-element result to spv.CompositeExtract,
/// indexing with the first entry of the op's position attribute. Results with
/// more than one element are not handled yet.
struct VectorExtractOpConvert final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only scalar (or single-element vector) results are supported for now.
    if (auto resultVecType = extractOp.getType().dyn_cast<VectorType>()) {
      if (resultVecType.getNumElements() > 1)
        return failure();
    }

    if (!getTypeConverter()->convertType(extractOp.getType()))
      return failure();

    Value container = adaptor.vector();

    // When the converted source is already a scalar there is nothing to index
    // into; it is itself the extracted value.
    if (container.getType().isa<spirv::ScalarType>()) {
      rewriter.replaceOp(extractOp, container);
      return success();
    }

    int32_t index = getFirstIntValue(extractOp.position());
    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, container,
                                                           index);
    return success();
  }
};
/// Lowers vector.extract_strided_slice with unit stride, using the first entry
/// of the offsets/sizes/strides attributes:
///  - converted source already a scalar -> forward the source unchanged;
///  - single-element (scalar-converted) result -> spv.CompositeExtract at
///    `offset`;
///  - otherwise -> spv.VectorShuffle selecting `size` consecutive lanes
///    starting at `offset`.
/// Non-unit strides are rejected.
struct VectorExtractStridedSliceOpConvert final
    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = getTypeConverter()->convertType(extractOp.getType());
    if (!dstType)
      return failure();

    uint64_t offset = getFirstIntValue(extractOp.offsets());
    uint64_t size = getFirstIntValue(extractOp.sizes());
    uint64_t stride = getFirstIntValue(extractOp.strides());
    if (stride != 1)
      return failure();

    Value srcVector = adaptor.vector();

    // Slicing a source that the type converter turned into a scalar (a
    // single-element vector): the slice can only be that element.
    // spv.CompositeExtract cannot index into a non-composite value, so the
    // source is forwarded instead.
    if (srcVector.getType().isa<spirv::ScalarType>()) {
      rewriter.replaceOp(extractOp, srcVector);
      return success();
    }

    // Extract vector<1xT> case.
    if (dstType.isa<spirv::ScalarType>()) {
      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
          extractOp, srcVector, static_cast<int32_t>(offset));
      return success();
    }

    // Shuffle the source with itself, keeping lanes [offset, offset + size).
    SmallVector<int32_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), static_cast<int32_t>(offset));
    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
        extractOp, dstType, srcVector, srcVector,
        rewriter.getI32ArrayAttr(indices));
    return success();
  }
};
/// Lowers vector.fma on a valid SPIR-V composite vector type to spv.GLSL.Fma,
/// feeding it the converted lhs, rhs and accumulator operands.
struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecType = fmaOp.getVectorType();
    if (!spirv::CompositeType::isValid(vecType))
      return failure();

    Value lhs = adaptor.lhs();
    Value rhs = adaptor.rhs();
    Value acc = adaptor.acc();
    rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(fmaOp, fmaOp.getType(), lhs,
                                                  rhs, acc);
    return success();
  }
};
/// Lowers vector.insert of a non-vector value into a valid SPIR-V composite
/// vector to spv.CompositeInsert, using the first entry of the op's position
/// attribute as the index. Inserting a vector value is not handled.
struct VectorInsertOpConvert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool insertsVector = insertOp.getSourceType().isa<VectorType>();
    if (insertsVector ||
        !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
      return failure();

    int32_t index = getFirstIntValue(insertOp.position());
    Value element = adaptor.source();
    Value container = adaptor.dest();
    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(insertOp, element,
                                                          container, index);
    return success();
  }
};
/// Lowers vector.extractelement on a valid SPIR-V composite vector to
/// spv.VectorExtractDynamic. Both the vector and the dynamic position come from
/// the adaptor, i.e. the values already remapped by the dialect conversion.
struct VectorExtractElementOpConvert final
    : public OpConversionPattern<vector::ExtractElementOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
      return failure();

    // Use the remapped position rather than the original operand: the original
    // value may still carry its pre-conversion type or be defined by an op the
    // conversion is replacing.
    rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
        extractElementOp, extractElementOp.getType(), adaptor.vector(),
        adaptor.position());
    return success();
  }
};
/// Lowers vector.insertelement into a valid SPIR-V composite vector to
/// spv.VectorInsertDynamic. The destination vector, the inserted value and the
/// dynamic position are all taken from the adaptor, i.e. the values already
/// remapped by the dialect conversion.
struct VectorInsertElementOpConvert final
    : public OpConversionPattern<vector::InsertElementOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
      return failure();

    // Use the remapped dest/position rather than the original operands, which
    // may still carry pre-conversion types or be defined by ops being replaced.
    rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
        insertElementOp, insertElementOp.getType(), adaptor.dest(),
        adaptor.source(), adaptor.position());
    return success();
  }
};
/// Lowers vector.insert_strided_slice with unit stride to a spv.VectorShuffle
/// of (dest, source). Lanes in [offset, offset + insertSize) take consecutive
/// lanes of the source. Every other lane keeps the destination's value.
/// Operands that were converted to scalars, and non-unit strides, are rejected.
struct VectorInsertStridedSliceOpConvert final
    : public OpConversionPattern<vector::InsertStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ValueRange operands = adaptor.getOperands();
    Value srcVector = operands.front();
    Value dstVector = operands.back();

    // Insert scalar values not supported yet.
    bool srcIsScalar = srcVector.getType().isa<spirv::ScalarType>();
    bool dstIsScalar = dstVector.getType().isa<spirv::ScalarType>();
    if (srcIsScalar || dstIsScalar)
      return failure();

    if (getFirstIntValue(insertOp.strides()) != 1)
      return failure();

    Type resultType = dstVector.getType();
    uint64_t totalSize = resultType.cast<VectorType>().getNumElements();
    uint64_t insertSize =
        srcVector.getType().cast<VectorType>().getNumElements();
    uint64_t offset = getFirstIntValue(insertOp.offsets());

    // VectorShuffle indexes the concatenation (dst, src): indices below
    // totalSize select destination lanes, indices from totalSize on select
    // source lanes.
    SmallVector<int32_t, 2> indices;
    indices.reserve(totalSize);
    for (uint64_t lane = 0; lane < totalSize; ++lane) {
      bool takesSource = lane >= offset && lane < offset + insertSize;
      uint64_t pick = takesSource ? totalSize + (lane - offset) : lane;
      indices.push_back(static_cast<int32_t>(pick));
    }

    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
        insertOp, resultType, dstVector, srcVector,
        rewriter.getI32ArrayAttr(indices));
    return success();
  }
};
} // namespace
/// Appends every Vector-to-SPIR-V conversion pattern defined in this file to
/// `patterns`. Each pattern is constructed with `typeConverter` and the
/// pattern set's own context.
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
               VectorExtractElementOpConvert, VectorExtractOpConvert,
               VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
               VectorInsertElementOpConvert, VectorInsertOpConvert,
               VectorInsertStridedSliceOpConvert>(typeConverter, context);
}