Files
clang-p2996/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
Mahesh Ravishankar a78bd84cf8 NFC: Refactor Dialect Conversion targeting SPIR-V.
Refactoring the conversion from StandardOps/GPU dialect to SPIR-V
dialect:
1) Move the SPIRVTypeConversion and SPIRVOpLowering class into SPIR-V
   dialect.
2) Add header files that expose functions to add patterns for the
   dialects to SPIR-V lowering, as well as a pass that does the
   dialect to SPIR-V lowering.
3) Make SPIRVOpLowering derive from OpLowering class.
PiperOrigin-RevId: 280486871
2019-11-14 12:34:54 -08:00

242 lines
9.6 KiB
C++

//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements patterns to convert Standard Ops to the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
/// Convert constant operation with IndexType return to SPIR-V constant
/// operation. Since IndexType is not used within SPIR-V dialect, this needs
/// special handling to make sure the result type and the type of the value
/// attribute are consistent.
class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
public:
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
return matchFailure();
}
// The attribute has index type which is not directly supported in
// SPIR-V. Get the integer value and create a new IntegerAttr.
auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
if (!constAttr) {
return matchFailure();
}
// Use the bitwidth set in the value attribute to decide the result type
// of the SPIR-V constant operation since SPIR-V does not support index
// types.
auto constVal = constAttr.getValue();
auto constValType = constAttr.getType().dyn_cast<IndexType>();
if (!constValType) {
return matchFailure();
}
auto spirvConstType =
typeConverter.convertBasicType(constIndexOp.getResult()->getType());
auto spirvConstVal =
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
spirvConstVal);
return matchSuccess();
}
};
/// Convert compare operation to SPIR-V dialect.
class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
public:
using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
CmpIOpOperandAdaptor cmpIOpOperands(operands);
switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
rewriter.replaceOpWithNewOp<spirvOp>( \
cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(), \
cmpIOpOperands.rhs()); \
return matchSuccess();
DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp);
DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp);
DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp);
DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp);
DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp);
#undef DISPATCH
default:
break;
}
return matchFailure();
}
};
/// Convert integer binary operations to SPIR-V operations. Cannot use
/// tablegen for this. If the integer operation is on variables of IndexType,
/// the type of the return value of the replacement operation differs from
/// that of the replaced operation. This is not handled in tablegen-based
/// pattern specification.
template <typename StdOp, typename SPIRVOp>
class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
public:
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(StdOp operation, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto resultType =
this->typeConverter.convertBasicType(operation.getResult()->getType());
rewriter.template replaceOpWithNewOp<SPIRVOp>(
operation, resultType, operands, ArrayRef<NamedAttribute>());
return this->matchSuccess();
}
};
/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
/// IndexType while that of the replacement operation are of type i32. This is
/// not supported in tablegen based pattern specification.
// TODO(ravishankarm) : These could potentially be templated on the operation
// being converted, since the same logic should work for linalg.load.
class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
public:
using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(LoadOp loadOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
LoadOpOperandAdaptor loadOperands(operands);
auto basePtr = loadOperands.memref();
auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
if (!ptrType) {
return matchFailure();
}
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
loadOp.getLoc(), basePtr, loadOperands.indices());
auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>();
rewriter.replaceOpWithNewOp<spirv::LoadOp>(
loadOp, loadPtrType.getPointeeType(), loadPtr,
/*memory_access =*/nullptr,
/*alignment =*/nullptr);
return matchSuccess();
}
};
/// Convert return -> spv.Return.
class ReturnToSPIRVConversion final : public SPIRVOpLowering<ReturnOp> {
public:
using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(ReturnOp returnOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
if (returnOp.getNumOperands()) {
return matchFailure();
}
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
return matchSuccess();
}
};
/// Convert select -> spv.Select
class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
public:
using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(SelectOp op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
SelectOpOperandAdaptor selectOperands(operands);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
selectOperands.true_value(),
selectOperands.false_value());
return matchSuccess();
}
};
/// Convert store -> spv.StoreOp. The operands of the replaced operation are
/// of IndexType while that of the replacement operation are of type i32. This
/// is not supported in tablegen based pattern specification.
// TODO(ravishankarm) : These could potentially be templated on the operation
// being converted, since the same logic should work for linalg.store.
class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
public:
using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(StoreOp storeOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
StoreOpOperandAdaptor storeOperands(operands);
auto value = storeOperands.value();
auto basePtr = storeOperands.memref();
auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
if (!ptrType) {
return matchFailure();
}
auto storePtr = rewriter.create<spirv::AccessChainOp>(
storeOp.getLoc(), basePtr, storeOperands.indices());
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, value,
/*memory_access =*/nullptr,
/*alignment =*/nullptr);
return matchSuccess();
}
};
} // namespace
namespace {
/// Import the Standard Ops to SPIR-V Patterns.
#include "StandardToSPIRV.cpp.inc"
} // namespace
namespace mlir {
void populateStandardToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
patterns
.insert<ConstantIndexOpConversion, CmpIOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>,
IntegerOpConversion<DivISOp, spirv::SDivOp>,
IntegerOpConversion<RemISOp, spirv::SModOp>,
IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
ReturnToSPIRVConversion, SelectOpConversion, StoreOpConversion>(
context, typeConverter);
}
} // namespace mlir