clang-p2996/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp

//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a TargetConfiguration for a reference fixed-point math
// quantization scheme based on the FxpMathOps dialect (plus a small category
// of extension ops that can be added from other dialects).
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Statistics.h"
#include "mlir/Quantizer/Support/UniformConstraints.h"
using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::fxpmath;
using namespace mlir::quant;
using namespace std::placeholders;

namespace {

struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
  FxpMathTargetConfigImpl(SolverContext &context)
      : FxpMathTargetConfig(context) {
    Builder b(&context.getMlirContext());
    IntegerType i8Type = b.getIntegerType(8);
    IntegerType i16Type = b.getIntegerType(16);
    IntegerType i32Type = b.getIntegerType(32);
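
    // Register the candidate quantized types the solver may choose from:
    // signed 8- and 16-bit uniform per-layer types, plus a signed 32-bit type
    // with an explicit fixed-point scale (used below for bias operands).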
    q8 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
                              std::numeric_limits<int8_t>::min(),
                              std::numeric_limits<int8_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q16 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
                              std::numeric_limits<int16_t>::min(),
                              std::numeric_limits<int16_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q32ExplicitFixedPoint = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
                              std::numeric_limits<int32_t>::min(),
                              std::numeric_limits<int32_t>::max()),
        CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);

    // Op handlers.
    addOpHandler<ConstantOp>(
        std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
    addOpHandler<ReturnOp>(
        std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
    addOpHandler<quant::StatisticsOp>(
        std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));

    // FxpMathOps.
    addOpHandler<RealAddEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
    addOpHandler<RealMulEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
    addOpHandler<RealMatMulOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
    addOpHandler<RealMatMulBiasOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));

    // Require stats ops: these real-math ops need range statistics attached
    // before quantized types can be chosen for them.
    addRequireStatsOp<RealAddEwOp>();
    addRequireStatsOp<RealSubEwOp>();
    addRequireStatsOp<RealDivEwOp>();
    addRequireStatsOp<RealMulEwOp>();
    addRequireStatsOp<RealMatMulOp>();
    addRequireStatsOp<RealMatMulBiasOp>();
  }
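
  // Only floating-point scalars and vectors/tensors with floating-point
  // elements participate in quantization.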
  bool isHandledType(Type t) const final {
    if (t.isa<FloatType>())
      return true;
    return (t.isa<VectorType>() || t.isa<TensorType>()) &&
           t.cast<ShapedType>().getElementType().isa<FloatType>();
  }
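
  // Couple every implied connection between anchors so that both ends solve
  // to the same quantized type parameters.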
  void finalizeAnchors(CAGSlice &cag) const override {
    cag.enumerateImpliedConnections(
        [&](CAGAnchorNode *from, CAGAnchorNode *to) {
          UniformConstraintsBuilder(cag).coupleAnchors(from, to);
        });
  }

  void addValueIdentityOpByName(StringRef opName) override {
    addOpHandlerByName(
        opName,
        std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
  }
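
  // Value-identity ops pass values through unchanged: force direct storage
  // pass-through and couple each handled operand to the result.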
  void handleValueIdentity(Operation *op, CAGSlice &cag) const {
    assert(op->getNumResults() == 1);
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::DirectStorage);

    for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
      if (!isHandledType(op->getOperand(opIdx).getType()))
        continue;
      auto operandNode = cag.getOperandAnchor(op, opIdx);
      operandNode->setTypeTransformRule(
          CAGAnchorNode::TypeTransformRule::DirectStorage);
      UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
    }
  }
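
  // Constants stay in their expressed (float) type; layer statistics are
  // computed directly from the constant's value attribute.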
  void handleConstant(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
    Attribute valueAttr;
    if (!matchPattern(op, m_Constant(&valueAttr))) {
      return;
    }

    AttributeTensorStatistics stats(valueAttr);
    TensorAxisStatistics layerStats;
    if (!stats.get(layerStats)) {
      op->emitOpError("could not compute statistics");
      return;
    }

    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

  void handleTerminal(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getOperand(0).getType()))
      return;
    auto operandNode = cag.getOperandAnchor(op, 0);
    operandNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
  }
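
  // A stats op records the observed min/max of the value flowing through it;
  // couple its operand and result and apply those statistics to the result.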
  void handleStats(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto argNode = cag.getOperandAnchor(op, 0);
    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);

    TensorAxisStatistics layerStats;
    auto statsOp = cast<quant::StatisticsOp>(op);
    auto layerStatsAttr = statsOp.layerStats();
    layerStats.minValue =
        layerStatsAttr.getValue<FloatAttr>(0).getValueAsDouble();
    layerStats.maxValue =
        layerStatsAttr.getValue<FloatAttr>(1).getValueAsDouble();
    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

  void handleAdd(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Add supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;

    // NOTE: We couple the add such that the scale/zeroPoint match between
    // both args and the result. This is overly constrained in that it is
    // possible to write efficient add kernels with a bit more freedom (e.g.
    // zeroPoints can vary, scales can differ by a power of two, etc.).
    // However, full coupling yields the simplest solutions on the fast path.
    // Further efficiency can be had by constraining the zeroPoint to 0, but
    // there isn't a constraint for this yet (and there are tradeoffs).
    UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
    UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);

    addRealMathOptionalConstraints(op, resultNode, cag);
  }
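
  // Unlike add, mul does not couple its operands to the result; operand and
  // result scales are free to differ.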
  void handleMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Mul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMatMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMatMulBias(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
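    // Restrict the bias to the 32-bit explicit fixed-point candidate so its
    // scale can be tied to the scale chosen for the result.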
    auto bias = cag.getOperandAnchor(op, 2);
    bias->getUniformMetadata().disabledCandidateTypes =
        getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});

    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);

    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
                                      CAGSlice &cag) const {
    // TODO: It would be nice if these all extended some base trait instead
    // of requiring name lookup.
    auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
    auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");

    if (clampMinAttr || clampMaxAttr) {
      auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
      auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
      auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
      UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
    }
  }
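
  // Ordinals of the candidate quantized types registered in the constructor.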
  unsigned q8;
  unsigned q16;
  unsigned q32ExplicitFixedPoint;
};

} // anonymous namespace

std::unique_ptr<FxpMathTargetConfig>
FxpMathTargetConfig::create(SolverContext &context) {
  return std::make_unique<FxpMathTargetConfigImpl>(context);
}