Files
clang-p2996/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
Mehdi Amini 308571074c Mass update the MLIR license header to mention "Part of the LLVM project"
This is an artifact from merging MLIR into LLVM, the file headers are
now aligned with the rest of the project.
2020-01-26 03:58:30 +00:00

394 lines
14 KiB
C++

//===- LowerUniformRealMath.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "UniformKernelUtils.h"
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::fxpmath;
using namespace mlir::fxpmath::detail;
using namespace mlir::quant;
namespace {
// Function pass whose runOnFunction() (defined later in this file) applies the
// elementwise add/mul lowering patterns to the current function.
struct LowerUniformRealMathPass
: public FunctionPass<LowerUniformRealMathPass> {
void runOnFunction() override;
};
// Function pass whose runOnFunction() (defined later in this file) applies the
// dequantize-cast lowering pattern to the current function.
struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
void runOnFunction() override;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Dequantize
//===----------------------------------------------------------------------===//
/// Emits the op sequence that dequantizes `input`, whose quantized element
/// type is the per-layer uniform type `elementType`.
///
/// The emitted sequence is: storage cast -> sign-extending convert to a 32-bit
/// integer intermediate -> optional add of the negated zero point -> convert
/// to the expressed (real) type -> multiply by the scale.
///
/// \param loc         Location attached to every created op and diagnostic.
/// \param input       The quantized value to dequantize.
/// \param elementType Quantization parameters (signedness, zero point, scale).
/// \param rewriter    Rewriter used to create the new ops.
/// \returns The dequantized value, or a null Value (after emitting a warning)
///          when the storage type is unsigned, which is not implemented yet.
static Value emitUniformPerLayerDequantize(Location loc, Value input,
                                           UniformQuantizedType elementType,
                                           PatternRewriter &rewriter) {
  // Pre-conditions.
  if (!elementType.isSigned()) {
    // TODO: Support unsigned storage type.
    // This branch is taken for *unsigned* storage, so the diagnostic must say
    // so.
    emitWarning(loc, "unimplemented: dequantize unsigned uniform");
    return nullptr;
  }

  Type storageType = elementType.castToStorageType(input.getType());
  Type realType = elementType.castToExpressedType(input.getType());
  // Validate the derived types before anything consumes them; storageType is
  // fed to castElementType immediately below.
  assert(storageType && "cannot cast to storage type");
  assert(realType && "cannot cast to expressed type");
  Type intermediateType =
      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));

  // Cast to storage type.
  input = rewriter.create<StorageCastOp>(loc, storageType, input);

  // Promote to intermediate type.
  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);

  // Apply zero-point offset (skipped entirely when the zero point is 0).
  if (elementType.getZeroPoint() != 0) {
    Value negZeroPointConst = rewriter.create<ConstantOp>(
        loc, broadcastScalarConstIntValue(intermediateType,
                                          -elementType.getZeroPoint()));
    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
  }

  // Convert to float.
  input = rewriter.create<ConvertISToFOp>(loc, realType, input);

  // Mul by scale.
  Value scaleConst = rewriter.create<ConstantOp>(
      loc, broadcastScalarConstFloatValue(realType,
                                          APFloat(elementType.getScale())));
  return rewriter.create<MulFOp>(loc, input, scaleConst);
}
/// Placeholder for per-axis uniform dequantization.
///
/// Nothing is lowered yet: a warning is reported at `loc` through the
/// context's diagnostic engine and a null Value is returned. `input` and
/// `elementType` are currently unused.
static Value
emitUniformPerAxisDequantize(Location loc, Value input,
                             UniformQuantizedPerAxisType elementType,
                             PatternRewriter &rewriter) {
  // TODO: Support per-axis dequantize.
  auto &diagEngine = rewriter.getContext()->getDiagEngine();
  diagEngine.emit(loc, DiagnosticSeverity::Warning)
      << "unimplemented: per-axis uniform dequantization";
  return nullptr;
}
/// Dispatches dequantization of `input` on the kind of its quantized element
/// type.
///
/// \returns The value produced by the matching per-layer or per-axis emitter
///          (which may itself be null), or a null Value when the element type
///          is neither of the two uniform quantized kinds.
static Value emitDequantize(Location loc, Value input,
                            PatternRewriter &rewriter) {
  QuantizedType quantElemType =
      QuantizedType::getQuantizedElementType(input.getType());

  // Per-layer uniform quantization.
  if (auto perLayerType =
          quantElemType.dyn_cast_or_null<UniformQuantizedType>())
    return emitUniformPerLayerDequantize(loc, input, perLayerType, rewriter);

  // Per-axis uniform quantization.
  if (auto perAxisType =
          quantElemType.dyn_cast_or_null<UniformQuantizedPerAxisType>())
    return emitUniformPerAxisDequantize(loc, input, perAxisType, rewriter);

  // Any other (or missing) quantized element type is not handled.
  return nullptr;
}
namespace {
/// Rewrites a DequantizeCastOp into the op sequence produced by
/// emitDequantize, provided the op's result type is exactly the expressed
/// type derived from its operand type.
struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
  using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(DequantizeCastOp op,
                                     PatternRewriter &rewriter) const override {
    Type sourceType = op.arg().getType();
    QuantizedType sourceElemType =
        QuantizedType::getQuantizedElementType(sourceType);

    // The result must be the expressed form of the operand type; otherwise
    // this is not a valid uniform cast.
    Type expectedResultType = sourceElemType.castToExpressedType(sourceType);
    if (expectedResultType != op.getResult().getType())
      return matchFailure();

    // A null value means the lowering for this element type is unavailable.
    Value lowered = emitDequantize(op.getLoc(), op.arg(), rewriter);
    if (!lowered)
      return matchFailure();

    rewriter.replaceOp(op, lowered);
    return matchSuccess();
  }
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//
/// Lowers an elementwise real add to integer arithmetic for the case where
/// both operands and the result share one signed uniform quantized type.
///
/// The sum is computed in a wider integer type, corrected by the (single)
/// zero point, clamped, narrowed back to the storage type, and finally
/// storage-cast to the quantized result type, which replaces `info.op`.
///
/// \returns success() after rewriting, failure() (with no IR changes) when
///          the result type is unsigned or the three types differ.
static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
                                      PatternRewriter &rewriter) {
  // Only the signed, fully isomorphic case is handled here.
  if (!(info.resultType.isSigned() && info.lhsType == info.resultType &&
        info.rhsType == info.resultType))
    return failure();

  Location loc = info.op->getLoc();

  // Choose a byte aligned intermediate width big enough to perform the
  // calculation without overflow.
  // TODO: This should probably be made just big enough to avoid overflow and
  // leave the downstream tooling to decide how to align that to machine
  // word sizes.
  const unsigned storageWidth = info.resultType.getStorageTypeIntegralWidth();
  const unsigned wideWidth = storageWidth <= 8 ? 16 : 32;
  IntegerType wideElemType = IntegerType::get(wideWidth, rewriter.getContext());
  Type wideType = castElementType(info.resultStorageType, wideElemType);

  // Reinterpret both operands as their storage types.
  Value lhs =
      rewriter.create<StorageCastOp>(loc, info.lhsStorageType, info.lhs)
          .getResult();
  Value rhs =
      rewriter.create<StorageCastOp>(loc, info.rhsStorageType, info.rhs)
          .getResult();

  // Widen to the intermediate type.
  lhs = rewriter.create<ConvertISOp>(loc, wideType, lhs);
  rhs = rewriter.create<ConvertISOp>(loc, wideType, rhs);

  // Add.
  Value sum = rewriter.create<AddIOp>(loc, lhs, rhs);

  // Zero point offset adjustment.
  //   result = (lhs - zp) + (rhs - zp) + zp
  //   zpOffset = -zp
  int zeroPointAdjust = -1 * info.resultType.getZeroPoint();
  if (zeroPointAdjust != 0) {
    Value adjustConst = rewriter.create<ConstantOp>(
        loc, broadcastScalarConstIntValue(wideType, zeroPointAdjust));
    sum = rewriter.create<AddIOp>(loc, sum, adjustConst);
  }

  // Clamp to the bounds supplied by the op info.
  auto clampBounds = info.getClampMinMax(wideElemType);
  sum = rewriter.create<ClampISOp>(loc, sum, clampBounds.first,
                                   clampBounds.second);

  // Narrow back to the storage type.
  sum = rewriter.create<ConvertISOp>(loc, info.resultStorageType, sum);

  // Cast back to the quantized type; this replaces the original op.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), sum);
  return success();
}
//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//
/// Lowers an elementwise real multiply to integer arithmetic when the result
/// type is signed.
///
/// Operands are storage-cast, widened to 32 bits and shifted by their negated
/// zero points; the product is rescaled by lhsScale * rhsScale / resultScale
/// using a fixed-point multiplier plus rounding shift, offset by the result
/// zero point, clamped, narrowed, and storage-cast to the quantized result
/// type, which replaces `info.op`.
///
/// \returns success() after rewriting; failure() (with no IR changes) when the
///          result type is unsigned, or when the real output multiplier
///          exceeds 1.0 (a warning is emitted on the op in that case).
static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
                            PatternRewriter &rewriter) {
  if (!info.resultType.isSigned())
    return failure();

  const double realMultiplier = info.lhsType.getScale() *
                                info.rhsType.getScale() /
                                info.resultType.getScale();
  // NOTE(review): a multiplier of exactly 1.0 passes this guard and reaches
  // QuantizedMultiplierSmallerThanOneExp below -- confirm that helper accepts
  // 1.0.
  if (realMultiplier > 1.0) {
    info.op->emitWarning(
        "unimplemented: cannot multiply with multiplier > 1.0");
    return failure();
  }

  Location loc = info.op->getLoc();

  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
  // avoid overflow.
  const unsigned wideWidth = 32;
  IntegerType wideElemType = IntegerType::get(wideWidth, rewriter.getContext());
  Type wideType = castElementType(info.resultStorageType, wideElemType);

  // Adds the constant built from `offset` to `value` in the intermediate
  // type; a zero offset emits no ops and returns `value` unchanged.
  auto addOffset = [&](Value value, int64_t offset) -> Value {
    if (offset == 0)
      return value;
    Value offsetConst = rewriter.create<ConstantOp>(
        loc, broadcastScalarConstIntValue(wideType, offset));
    return rewriter.create<AddIOp>(loc, value, offsetConst);
  };

  // Reinterpret both operands as their storage types.
  Value lhs =
      rewriter.create<StorageCastOp>(loc, info.lhsStorageType, info.lhs)
          .getResult();
  Value rhs =
      rewriter.create<StorageCastOp>(loc, info.rhsStorageType, info.rhs)
          .getResult();

  // Widen to the intermediate type.
  lhs = rewriter.create<ConvertISOp>(loc, wideType, lhs);
  rhs = rewriter.create<ConvertISOp>(loc, wideType, rhs);

  // Remove each argument's zero point (lhs first, then rhs).
  lhs = addOffset(lhs, -info.lhsType.getZeroPoint());
  rhs = addOffset(rhs, -info.rhsType.getZeroPoint());

  // Mul.
  Value product = rewriter.create<MulIOp>(loc, lhs, rhs);

  // Scale output: fixed-point multiply followed by a rounding right shift by
  // the negated exponent.
  QuantizedMultiplierSmallerThanOneExp fixedPoint(realMultiplier);
  product = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
      loc, product, IntegerAttr::get(wideElemType, fixedPoint.multiplier));
  product = rewriter.create<RoundingDivideByPotISOp>(
      loc, product, IntegerAttr::get(wideElemType, -fixedPoint.exponent));

  // Re-apply the result zero point.
  product = addOffset(product, info.resultType.getZeroPoint());

  // Clamp to the bounds supplied by the op info.
  auto clampBounds = info.getClampMinMax(wideElemType);
  product = rewriter.create<ClampISOp>(loc, product, clampBounds.first,
                                       clampBounds.second);

  // Narrow back to the storage type.
  product = rewriter.create<ConvertISOp>(loc, info.resultStorageType, product);

  // Cast back to the quantized type; this replaces the original op.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), product);
  return success();
}
namespace {
/// Rewrites a RealAddEwOp via the supported integer lowerings; currently only
/// tryRewriteAffineAddEwIsomorphicSigned is attempted.
struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
  using OpRewritePattern<RealAddEwOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(RealAddEwOp op,
                                     PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());

    // Bail out on invalid op info before attempting any lowering; the
    // short-circuit keeps the rewrite from running in that case.
    if (!info.isValid() ||
        failed(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter)))
      return matchFailure();

    return matchSuccess();
  }
};
/// Rewrites a RealMulEwOp via the supported integer lowerings; currently only
/// tryRewriteAffineMulEwSigned is attempted.
struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
  using OpRewritePattern<RealMulEwOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(RealMulEwOp op,
                                     PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());

    // Bail out on invalid op info before attempting any lowering; the
    // short-circuit keeps the rewrite from running in that case.
    if (!info.isValid() ||
        failed(tryRewriteAffineMulEwSigned(info, rewriter)))
      return matchFailure();

    return matchSuccess();
  }
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// LowerUniformRealMath pass
//===----------------------------------------------------------------------===//
void LowerUniformRealMathPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
applyPatternsGreedily(fn, patterns);
}
// Factory for the real-math lowering pass. Returns a newly heap-allocated
// LowerUniformRealMathPass; presumably the caller (e.g. a pass manager) takes
// ownership -- confirm against the declaration in Passes.h.
OpPassBase<FuncOp> *mlir::fxpmath::createLowerUniformRealMathPass() {
return new LowerUniformRealMathPass();
}
// Static registration object for LowerUniformRealMathPass, constructed with
// the argument string "fxpmath-lower-uniform-real-math" and a one-line
// description.
static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
"fxpmath-lower-uniform-real-math",
"Lowers uniform-quantized real math ops to integer arithmetic.");
//===----------------------------------------------------------------------===//
// LowerUniformCasts pass
//===----------------------------------------------------------------------===//
void LowerUniformCastsPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.insert<UniformDequantizePattern>(context);
applyPatternsGreedily(fn, patterns);
}
// Factory for the cast lowering pass. Returns a newly heap-allocated
// LowerUniformCastsPass; presumably the caller (e.g. a pass manager) takes
// ownership -- confirm against the declaration in Passes.h.
OpPassBase<FuncOp> *mlir::fxpmath::createLowerUniformCastsPass() {
return new LowerUniformCastsPass();
}
// Static registration object for LowerUniformCastsPass, constructed with the
// argument string "fxpmath-lower-uniform-casts" and a one-line description.
static PassRegistration<LowerUniformCastsPass>
lowerUniformCastsPass("fxpmath-lower-uniform-casts",
"Lowers uniform-quantized casts.");