Files
clang-p2996/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
River Riddle b54c724be0 [mlir:OpConversionPattern] Add overloads for taking an Adaptor instead of ArrayRef
This has been a TODO for a long time, and it brings about many advantages (namely nice accessors, and less fragile code). The existing overloads that accept ArrayRef are now treated as deprecated and will be removed in a followup (after a small grace period). Most of the upstream MLIR usages have been fixed by this commit, the rest will be handled in a followup.

Differential Revision: https://reviews.llvm.org/D110293
2021-09-24 17:51:41 +00:00

102 lines
4.0 KiB
C++

//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Math dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "math-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern, whereas DRR
// only generates plain RewritePattern.
namespace {
/// Generic 1:1 lowering of a source operation `StdOp` with at most two
/// operands to the SPIR-V operation `SPIRVOp`.
///
/// The result type is run through the pattern's type converter and the new
/// op is built from that converted type plus the already-converted operands
/// supplied by the adaptor.
template <typename StdOp, typename SPIRVOp>
class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
public:
  using OpConversionPattern<StdOp>::OpConversionPattern;

  /// Replaces `operation` with a `SPIRVOp`.
  ///
  /// Returns failure when the result type cannot be converted. Emits an
  /// error (and fails) when `SPIRVOp` carries the UnsignedOp trait and type
  /// conversion changed the result type. Returns success otherwise.
  LogicalResult
  matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedOperands = adaptor.getOperands();
    assert(convertedOperands.size() <= 2);

    const auto sourceType = operation.getType();
    auto resultType = this->getTypeConverter()->convertType(sourceType);
    if (!resultType)
      return failure();

    // Unsigned SPIR-V ops are only handled when the result type is left
    // untouched by conversion; a changed type is reported as an error.
    if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>()) {
      if (resultType != sourceType)
        return operation.emitError(
            "bitwidth emulation is not implemented yet on unsigned op");
    }

    rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, resultType,
                                                  convertedOperands);
    return success();
  }
};
/// Converts math.log1p to SPIR-V ops.
///
/// SPIR-V does not have a direct operation for log(1+x), so the op is
/// expanded explicitly: a constant one is added to the operand with
/// spirv::FAddOp and the sum is fed to spirv::GLSLLogOp.
class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
public:
  using OpConversionPattern<math::Log1pOp>::OpConversionPattern;

  /// Replaces `operation` with the expanded add + log sequence.
  ///
  /// Returns failure when the operand type cannot be converted by the
  /// pattern's type converter; returns success once the op is replaced.
  LogicalResult
  matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(adaptor.getOperands().size() == 1);
    Location loc = operation.getLoc();
    auto type =
        this->getTypeConverter()->convertType(operation.operand().getType());
    // A null type means the conversion failed. Bail out instead of building
    // the constant and the replacement ops with an invalid type.
    if (!type)
      return failure();

    auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
    auto onePlus =
        rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
    rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
namespace mlir {
/// Appends the Math-to-SPIR-V conversion patterns defined in this file to
/// `patterns`. Every pattern is constructed with `typeConverter` and the
/// context owned by `patterns`.
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns) {
  auto *context = patterns.getContext();

  // math.log1p is expanded by a dedicated pattern.
  patterns.add<Log1pOpPattern>(typeConverter, context);

  // Direct 1:1 lowerings onto the GLSL extended instruction ops.
  patterns.add<UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
               UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
               UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
               UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
               UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
               UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
               UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
               UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
      typeConverter, context);
}
} // namespace mlir