Files
clang-p2996/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
Mehdi Amini 308571074c Mass update the MLIR license header to mention "Part of the LLVM project"
This is an artifact from merging MLIR into LLVM, the file headers are
now aligned with the rest of the project.
2020-01-26 03:58:30 +00:00

113 lines
3.9 KiB
C++

//===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/Passes.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/QuantOps/QuantizeUtils.h"
#include "mlir/Dialect/QuantOps/UniformSupport.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::quant;
namespace {
/// Function pass that greedily applies QuantizedConstRewrite to the function
/// it runs on (see runOnFunction).
struct ConvertConstPass : public FunctionPass<ConvertConstPass> {
  /// Pass entry point, invoked once per function.
  void runOnFunction() override;
};
/// Rewrite pattern rooted at QuantizeCastOp. The matching and rewriting logic
/// lives in the out-of-line definition of matchAndRewrite.
struct QuantizedConstRewrite : OpRewritePattern<QuantizeCastOp> {
  // Inherit the base pattern's constructors.
  using OpRewritePattern::OpRewritePattern;

  /// Attempts to match `qbarrier` and, on success, rewrites it through
  /// `rewriter`. Returns whether the pattern matched.
  PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
                                     PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
/// Matches a [constant] -> [qbarrier] pair in which the qbarrier's result type
/// is quantized and its operand type is quantizable. On success, a new
/// constant holding the quantized value is created and the qbarrier is
/// replaced by a storage cast from that constant to the qbarrier's type.
PatternMatchResult
QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
                                       PatternRewriter &rewriter) const {
  // The operand has to be produced by a constant; capture its value.
  auto operand = qbarrier.arg();
  Attribute constAttr;
  if (!matchPattern(operand, m_Constant(&constAttr)))
    return matchFailure();

  // The qbarrier must produce a quantized type that can also be cast to its
  // storage type. Either check fails when no quantized type has been chosen
  // yet or when the storage-type cast is unsupported.
  Type resultType = qbarrier.getResult().getType();
  QuantizedType quantElemType =
      QuantizedType::getQuantizedElementType(resultType);
  if (!quantElemType || !QuantizedType::castToStorageType(resultType))
    return matchFailure();

  // The operand type must be compatible with the expressed type of the
  // quantized type. A superfluous qbarrier (quantized in, quantized out) is
  // rejected here.
  if (!quantElemType.isCompatibleExpressedType(operand.getType()))
    return matchFailure();

  // Only float, dense-elements and sparse-elements constants are handled.
  const bool isSupportedAttr = constAttr.isa<FloatAttr>() ||
                               constAttr.isa<DenseElementsAttr>() ||
                               constAttr.isa<SparseElementsAttr>();
  if (!isSupportedAttr)
    return matchFailure();

  // Quantize the constant; the type of the quantized value is reported
  // through `quantizedValueType`.
  Type quantizedValueType;
  auto quantizedValue =
      quantizeAttr(constAttr, quantElemType, quantizedValueType);
  if (!quantizedValue)
    return matchFailure();

  // The replacement constant gets a fused location combining the original
  // constant with the qbarrier that triggered the quantization.
  Location constLoc = operand.getDefiningOp()->getLoc();
  Location castLoc = qbarrier.getLoc();
  auto fusedLoc = FusedLoc::get({constLoc, castLoc}, rewriter.getContext());

  auto quantizedConst = rewriter.create<ConstantOp>(
      fusedLoc, quantizedValueType, quantizedValue);
  rewriter.replaceOpWithNewOp<StorageCastOp>({operand}, qbarrier,
                                             qbarrier.getType(),
                                             quantizedConst);
  return matchSuccess();
}
void ConvertConstPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *context = &getContext();
patterns.insert<QuantizedConstRewrite>(context);
applyPatternsGreedily(func, patterns);
}
/// Factory for the constant-conversion pass: returns a newly allocated
/// ConvertConstPass, owned by the caller, typed as a pass over FuncOp.
std::unique_ptr<OpPassBase<FuncOp>> mlir::quant::createConvertConstPass() {
return std::make_unique<ConvertConstPass>();
}
// File-local registration object for ConvertConstPass, constructed with the
// name "quant-convert-const" and a one-line description of the pass.
static PassRegistration<ConvertConstPass>
pass("quant-convert-const",
"Converts constants followed by qbarrier to actual quantized values");