This commit merges the DRR pattern for std.constant to spv.constant conversion into the C++ OpConversionPattern. This allows us to remove the DRR pattern file. Along the way, this commit enhances the std.constant to spv.constant conversion to take type conversion into account, which means converting the underlying attributes if necessary. Differential Revision: https://reviews.llvm.org/D76246
600 lines
23 KiB
C++
600 lines
23 KiB
C++
//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements patterns to convert standard ops to SPIR-V ops.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
|
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
|
|
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "std-to-spirv-pattern"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utility functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns true if the given `type` is a boolean scalar or vector type.
|
|
static bool isBoolScalarOrVector(Type type) {
|
|
if (type.isInteger(1))
|
|
return true;
|
|
if (auto vecType = type.dyn_cast<VectorType>())
|
|
return vecType.getElementType().isInteger(1);
|
|
return false;
|
|
}
|
|
|
|
/// Converts the given `srcAttr` into a boolean attribute if it holds a integral
|
|
/// value. Returns null attribute if conversion fails.
|
|
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
|
|
if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
|
|
return boolAttr;
|
|
if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
|
|
return builder.getBoolAttr(intAttr.getValue().getBoolValue());
|
|
return BoolAttr();
|
|
}
|
|
|
|
/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
|
|
/// Returns null attribute if conversion fails.
|
|
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
|
|
Builder builder) {
|
|
// If the source number uses less active bits than the target bitwidth, then
|
|
// it should be safe to convert.
|
|
if (srcAttr.getValue().isIntN(dstType.getWidth()))
|
|
return builder.getIntegerAttr(dstType, srcAttr.getInt());
|
|
|
|
// XXX: Try again by interpreting the source number as a signed value.
|
|
// Although integers in the standard dialect are signless, they can represent
|
|
// a signed number. It's the operation decides how to interpret. This is
|
|
// dangerous, but it seems there is no good way of handling this if we still
|
|
// want to change the bitwidth. Emit a message at least.
|
|
if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
|
|
auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
|
|
LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
|
|
<< dstAttr << "' for type '" << dstType << "'\n");
|
|
return dstAttr;
|
|
}
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
|
|
<< "' illegal: cannot fit into target type '"
|
|
<< dstType << "'\n");
|
|
return IntegerAttr();
|
|
}
|
|
|
|
/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
|
|
/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
|
|
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
|
|
Builder builder) {
|
|
// Only support converting to float for now.
|
|
if (!dstType.isF32())
|
|
return FloatAttr();
|
|
|
|
// Try to convert the source floating-point number to single precision.
|
|
APFloat dstVal = srcAttr.getValue();
|
|
bool losesInfo = false;
|
|
APFloat::opStatus status =
|
|
dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
|
|
if (status != APFloat::opOK || losesInfo) {
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< srcAttr << " illegal: cannot fit into converted type '"
|
|
<< dstType << "'\n");
|
|
return FloatAttr();
|
|
}
|
|
|
|
return builder.getF32FloatAttr(dstVal.convertToFloat());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operation conversion
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Note that DRR cannot be used for the patterns in this file: we may need to
|
|
// convert type along the way, which requires ConversionPattern. DRR generates
|
|
// normal RewritePattern.
|
|
|
|
namespace {
|
|
|
|
/// Converts binary standard operations to SPIR-V operations.
|
|
template <typename StdOp, typename SPIRVOp>
|
|
class BinaryOpPattern final : public SPIRVOpLowering<StdOp> {
|
|
public:
|
|
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
assert(operands.size() == 2);
|
|
auto dstType = this->typeConverter.convertType(operation.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands,
|
|
ArrayRef<NamedAttribute>());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts bitwise standard operations to SPIR-V operations. This is a special
|
|
/// pattern other than the BinaryOpPatternPattern because if the operands are
|
|
/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
|
|
/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
|
|
template <typename StdOp, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
|
|
class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
|
|
public:
|
|
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
assert(operands.size() == 2);
|
|
auto dstType =
|
|
this->typeConverter.convertType(operation.getResult().getType());
|
|
if (!dstType)
|
|
return failure();
|
|
if (isBoolScalarOrVector(operands.front().getType())) {
|
|
rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
|
|
operation, dstType, operands, ArrayRef<NamedAttribute>());
|
|
} else {
|
|
rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
|
|
operation, dstType, operands, ArrayRef<NamedAttribute>());
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts composite std.constant operation to spv.constant.
|
|
class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> {
|
|
public:
|
|
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts scalar std.constant operation to spv.constant.
|
|
class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> {
|
|
public:
|
|
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts floating-point comparison operations to SPIR-V ops.
|
|
class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> {
|
|
public:
|
|
using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts integer compare operation to SPIR-V ops.
|
|
class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
|
|
public:
|
|
using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts std.load to spv.Load.
|
|
class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
|
|
public:
|
|
using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts std.return to spv.Return.
|
|
class ReturnOpPattern final : public SPIRVOpLowering<ReturnOp> {
|
|
public:
|
|
using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts std.select to spv.Select.
|
|
class SelectOpPattern final : public SPIRVOpLowering<SelectOp> {
|
|
public:
|
|
using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
|
|
LogicalResult
|
|
matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts std.store to spv.Store.
|
|
class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
|
|
public:
|
|
using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts type-casting standard operations to SPIR-V operations.
|
|
template <typename StdOp, typename SPIRVOp>
|
|
class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
|
|
public:
|
|
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
assert(operands.size() == 1);
|
|
auto dstType =
|
|
this->typeConverter.convertType(operation.getResult().getType());
|
|
if (dstType == operands.front().getType()) {
|
|
// Due to type conversion, we are seeing the same source and target type.
|
|
// Then we can just erase this operation by forwarding its operand.
|
|
rewriter.replaceOp(operation, operands.front());
|
|
} else {
|
|
rewriter.template replaceOpWithNewOp<SPIRVOp>(
|
|
operation, dstType, operands, ArrayRef<NamedAttribute>());
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts std.xor to SPIR-V operations.
|
|
class XOrOpPattern final : public SPIRVOpLowering<XOrOp> {
|
|
public:
|
|
using SPIRVOpLowering<XOrOp>::SPIRVOpLowering;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantOp with composite type.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites a std.constant of vector or ranked tensor type into spv.constant.
///
/// The match fails when:
///   * the constant is not of shaped type,
///   * the type cannot be converted,
///   * the value is not a `DenseElementsAttr`,
///   * the constant is a vector with more than one dimension,
///   * the element type changes and the source elements are booleans, or
///   * an individual element cannot be converted to the new element type.
///
/// Tensors with more than one dimension are linearized into a 1-D elements
/// attribute. When type conversion changes the element type, every element
/// attribute is converted to the destination element type.
LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
    ConstantOp constOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto srcType = constOp.getType().dyn_cast<ShapedType>();
  if (!srcType)
    return failure();

  // std.constant should only have vector or tensor types.
  assert(srcType.isa<VectorType>() || srcType.isa<RankedTensorType>());

  auto dstType = typeConverter.convertType(srcType);
  if (!dstType)
    return failure();

  // The dyn_cast yields a null attribute when the value is of another
  // attribute kind, so it must be checked before its type is queried.
  auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
  if (!dstElementsAttr)
    return failure();
  ShapedType dstAttrType = dstElementsAttr.getType();

  // If the composite type has more than one dimensions, perform linearization.
  if (srcType.getRank() > 1) {
    if (srcType.isa<RankedTensorType>()) {
      dstAttrType = RankedTensorType::get(srcType.getNumElements(),
                                          srcType.getElementType());
      dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
    } else {
      // TODO(antiagainst): add support for large vectors.
      return failure();
    }
  }

  Type srcElemType = srcType.getElementType();
  Type dstElemType;
  // Tensor types are converted to SPIR-V array types; vector types are
  // converted to SPIR-V vector/array types.
  if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
    dstElemType = arrayType.getElementType();
  else
    dstElemType = dstType.cast<VectorType>().getElementType();

  // If the source and destination element types are different, perform
  // attribute conversion.
  if (srcElemType != dstElemType) {
    SmallVector<Attribute, 8> elements;
    if (srcElemType.isa<FloatType>()) {
      for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
        FloatAttr dstAttr = convertFloatAttr(
            srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
        if (!dstAttr)
          return failure();
        elements.push_back(dstAttr);
      }
    } else if (srcElemType.isInteger(1)) {
      // Changing the element type of boolean elements is not handled.
      return failure();
    } else {
      for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
        IntegerAttr dstAttr =
            convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
                               dstElemType.cast<IntegerType>(), rewriter);
        if (!dstAttr)
          return failure();
        elements.push_back(dstAttr);
      }
    }

    // Unfortunately, we cannot use dialect-specific types for element
    // attributes; element attributes only works with standard types. So we need
    // to prepare another converted standard types for the destination elements
    // attribute.
    if (dstAttrType.isa<RankedTensorType>())
      dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
    else
      dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);

    dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
  }

  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
                                                 dstElementsAttr);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantOp with scalar type.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites a std.constant of integer, index or floating-point type into
/// spv.constant, converting the value attribute to the converted type.
///
/// Fails when the constant is not of such a scalar type, when its type cannot
/// be converted, or when the value attribute cannot be converted.
LogicalResult ConstantScalarOpPattern::matchAndRewrite(
    ConstantOp constOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  Type srcType = constOp.getType();
  if (!srcType.isIntOrIndexOrFloat())
    return failure();

  Type dstType = typeConverter.convertType(srcType);
  if (!dstType)
    return failure();

  Attribute dstAttr;
  if (srcType.isa<FloatType>()) {
    // Floating-point types not supported in the target environment are all
    // converted to float type; the attribute is only converted when the type
    // actually changes.
    auto floatAttr = constOp.value().cast<FloatAttr>();
    dstAttr = srcType == dstType
                  ? floatAttr
                  : convertFloatAttr(floatAttr, dstType.cast<FloatType>(),
                                     rewriter);
  } else if (srcType.isInteger(1)) {
    // Bool type. std.constant can use 0/1 instead of true/false for i1 values,
    // which is handled by the attribute conversion.
    dstAttr = convertBoolAttr(constOp.value(), rewriter);
  } else {
    // IndexType or IntegerType. Index values are converted to 32-bit integer
    // values when converting to SPIR-V.
    dstAttr = convertIntegerAttr(constOp.value().cast<IntegerAttr>(),
                                 dstType.cast<IntegerType>(), rewriter);
  }

  if (!dstAttr)
    return failure();

  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites std.cmpf into the SPIR-V floating-point comparison op that
/// corresponds to its predicate. The new op reuses the result type of the
/// original op. Predicates without a case below cause the match to fail.
LogicalResult
CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  CmpFOpOperandAdaptor adaptor(operands);
  Type resultType = cmpFOp.getResult().getType();
  Value lhs = adaptor.lhs();
  Value rhs = adaptor.rhs();

  switch (cmpFOp.getPredicate()) {
#define LOWER_PREDICATE(predicate, spirvCmpOp)                                 \
  case predicate:                                                              \
    rewriter.replaceOpWithNewOp<spirvCmpOp>(cmpFOp, resultType, lhs, rhs);     \
    return success();

    // Ordered comparisons.
    LOWER_PREDICATE(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
    LOWER_PREDICATE(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
    LOWER_PREDICATE(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
    LOWER_PREDICATE(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
    LOWER_PREDICATE(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
    LOWER_PREDICATE(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
    // Unordered comparisons.
    LOWER_PREDICATE(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
    LOWER_PREDICATE(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
    LOWER_PREDICATE(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
    LOWER_PREDICATE(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
    LOWER_PREDICATE(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
    LOWER_PREDICATE(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);

#undef LOWER_PREDICATE

  default:
    break;
  }
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites std.cmpi into the SPIR-V integer comparison op that corresponds
/// to its predicate. The new op reuses the result type of the original op.
LogicalResult
CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  CmpIOpOperandAdaptor adaptor(operands);
  Type resultType = cmpIOp.getResult().getType();
  Value lhs = adaptor.lhs();
  Value rhs = adaptor.rhs();

  switch (cmpIOp.getPredicate()) {
#define LOWER_PREDICATE(predicate, spirvCmpOp)                                 \
  case predicate:                                                              \
    rewriter.replaceOpWithNewOp<spirvCmpOp>(cmpIOp, resultType, lhs, rhs);     \
    return success();

    // Equality.
    LOWER_PREDICATE(CmpIPredicate::eq, spirv::IEqualOp);
    LOWER_PREDICATE(CmpIPredicate::ne, spirv::INotEqualOp);
    // Signed ordering.
    LOWER_PREDICATE(CmpIPredicate::slt, spirv::SLessThanOp);
    LOWER_PREDICATE(CmpIPredicate::sle, spirv::SLessThanEqualOp);
    LOWER_PREDICATE(CmpIPredicate::sgt, spirv::SGreaterThanOp);
    LOWER_PREDICATE(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
    // Unsigned ordering.
    LOWER_PREDICATE(CmpIPredicate::ult, spirv::ULessThanOp);
    LOWER_PREDICATE(CmpIPredicate::ule, spirv::ULessThanEqualOp);
    LOWER_PREDICATE(CmpIPredicate::ugt, spirv::UGreaterThanOp);
    LOWER_PREDICATE(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef LOWER_PREDICATE
  }
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LoadOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites std.load into spv.Load. The pointer to load from is obtained from
/// `spirv::getElementPtr` using the converted memref and index operands.
LogicalResult
LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  LoadOpOperandAdaptor adaptor(operands);

  // The memref type is taken from the original (unconverted) op.
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  auto elementPtr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
                           adaptor.indices(), loadOp.getLoc(), rewriter);

  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, elementPtr);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReturnOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites an operand-less std.return into spv.Return. A return op that
/// carries operands is not matched.
LogicalResult
ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  const bool returnsValues = returnOp.getNumOperands() != 0;
  if (!returnsValues) {
    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
    return success();
  }
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SelectOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites std.select into spv.Select, passing the converted condition,
/// true value and false value through unchanged.
LogicalResult
SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  SelectOpOperandAdaptor adaptor(operands);
  Value condition = adaptor.condition();
  Value trueValue = adaptor.true_value();
  Value falseValue = adaptor.false_value();

  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, condition, trueValue,
                                               falseValue);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// StoreOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites std.store into spv.Store. The pointer to store to is obtained
/// from `spirv::getElementPtr` using the converted memref and index operands.
LogicalResult
StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) const {
  StoreOpOperandAdaptor adaptor(operands);

  // The memref type is taken from the original (unconverted) op.
  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
  auto elementPtr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
                           adaptor.indices(), storeOp.getLoc(), rewriter);

  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, elementPtr,
                                              adaptor.value());
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// XorOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites std.xor on non-boolean operands into spv.BitwiseXor. The match
/// fails for boolean scalar/vector operands and when the result type cannot
/// be converted.
LogicalResult
XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
                              ConversionPatternRewriter &rewriter) const {
  assert(operands.size() == 2);

  // Boolean operands are not handled by this pattern.
  const bool hasBoolOperands = isBoolScalarOrVector(operands.front().getType());
  if (hasBoolOperands)
    return failure();

  Type resultType = typeConverter.convertType(xorOp.getType());
  if (resultType) {
    rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(
        xorOp, resultType, operands, ArrayRef<NamedAttribute>());
    return success();
  }
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pattern population
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace mlir {
|
|
void populateStandardToSPIRVPatterns(MLIRContext *context,
|
|
SPIRVTypeConverter &typeConverter,
|
|
OwningRewritePatternList &patterns) {
|
|
patterns.insert<
|
|
BinaryOpPattern<AddFOp, spirv::FAddOp>,
|
|
BinaryOpPattern<AddIOp, spirv::IAddOp>,
|
|
BinaryOpPattern<DivFOp, spirv::FDivOp>,
|
|
BinaryOpPattern<MulFOp, spirv::FMulOp>,
|
|
BinaryOpPattern<MulIOp, spirv::IMulOp>,
|
|
BinaryOpPattern<RemFOp, spirv::FRemOp>,
|
|
BinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
|
|
BinaryOpPattern<SignedShiftRightOp, spirv::ShiftRightArithmeticOp>,
|
|
BinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
|
|
BinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
|
|
BinaryOpPattern<SubFOp, spirv::FSubOp>,
|
|
BinaryOpPattern<SubIOp, spirv::ISubOp>,
|
|
BinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
|
|
BinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
|
|
BinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
|
|
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
|
|
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
|
|
ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern,
|
|
CmpIOpPattern, LoadOpPattern, ReturnOpPattern, SelectOpPattern,
|
|
StoreOpPattern, TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
|
|
TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
|
|
TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(
|
|
context, typeConverter);
|
|
}
|
|
} // namespace mlir
|