clang-p2996/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
Commit a4c3a6455c by River Riddle: Move the emitError/Warning/Remark utility methods out of MLIRContext and into the mlir namespace.
Now that Locations are attributes, they have direct access to the MLIR context. This allows for simplifying error emission by removing unnecessary context lookups.

PiperOrigin-RevId: 255112791
2019-06-25 21:32:23 -07:00
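As a minimal sketch of the change the commit describes (the helper name below is hypothetical; only the mlir::emitWarning free function is taken from the commit), a Location is now an attribute and therefore carries its MLIRContext, so diagnostics can be emitted without a separate context lookup:

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"

// Hypothetical helper: emit a warning directly from a Location. Because
// Locations are attributes, they already know their MLIRContext, so no
// getContext()/DiagEngine plumbing is needed at the call site.
static void warnUnsupported(mlir::Location loc) {
  mlir::emitWarning(loc, "unimplemented: feature not yet supported");
}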


//===- LowerUniformRealMath.cpp ------------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "UniformKernelUtils.h"
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
using namespace mlir;
using namespace mlir::fxpmath;
using namespace mlir::fxpmath::detail;
using namespace mlir::quant;
namespace {

struct LowerUniformRealMathPass
    : public FunctionPass<LowerUniformRealMathPass> {
  void runOnFunction() override;
};

struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
  void runOnFunction() override;
};

} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Dequantize
//===----------------------------------------------------------------------===//
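// Emits the op sequence computing real = scale * (storage - zeroPoint) for a
// per-layer uniform quantized value: cast to the storage type, widen to i32,
// subtract the zero point, convert to float, and multiply by the scale.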
static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
                                            UniformQuantizedType elementType,
                                            PatternRewriter &rewriter) {
  // Pre-conditions.
  if (!elementType.isSigned()) {
    // TODO: Support unsigned storage type.
    emitWarning(loc, "unimplemented: dequantize unsigned uniform");
    return nullptr;
  }

  Type storageType = elementType.castToStorageType(input->getType());
  Type realType = elementType.castToExpressedType(input->getType());
  Type intermediateType = castElementType(
      storageType, IntegerType::get(32, rewriter.getContext()));
  assert(storageType && "cannot cast to storage type");
  assert(realType && "cannot cast to expressed type");

  // Cast to storage type.
  input = rewriter.create<StorageCastOp>(loc, storageType, input);

  // Promote to intermediate type.
  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);

  // Apply zero-point offset.
  if (elementType.getZeroPoint() != 0) {
    Value *negZeroPointConst = rewriter.create<ConstantOp>(
        loc, broadcastScalarConstIntValue(intermediateType,
                                          -elementType.getZeroPoint()));
    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
  }

  // Convert to float.
  input = rewriter.create<ConvertISToFOp>(loc, realType, input);

  // Mul by scale.
  Value *scaleConst = rewriter.create<ConstantOp>(
      loc, broadcastScalarConstFloatValue(realType,
                                          APFloat(elementType.getScale())));
  return rewriter.create<MulFOp>(loc, input, scaleConst);
}
static Value *
emitUniformPerAxisDequantize(Location loc, Value *input,
                             UniformQuantizedPerAxisType elementType,
                             PatternRewriter &rewriter) {
  // TODO: Support per-axis dequantize.
  emitWarning(loc, "unimplemented: per-axis uniform dequantization");
  return nullptr;
}
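// Dequantizes the given input, dispatching on its quantized element type.
// Returns nullptr (after emitting a diagnostic where appropriate) if the
// input is not a supported uniform quantized type.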
static Value *emitDequantize(Location loc, Value *input,
                             PatternRewriter &rewriter) {
  Type inputType = input->getType();
  QuantizedType qElementType =
      QuantizedType::getQuantizedElementType(inputType);
  if (auto uperLayerElementType =
          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
    return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
                                         rewriter);
  } else if (auto uperAxisElementType =
                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
    return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
                                        rewriter);
  } else {
    return nullptr;
  }
}
namespace {

struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
  using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(DequantizeCastOp op,
                                     PatternRewriter &rewriter) const override {
    Type inputType = op.arg()->getType();
    Type outputType = op.getResult()->getType();

    QuantizedType inputElementType =
        QuantizedType::getQuantizedElementType(inputType);
    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
    if (expressedOutputType != outputType) {
      // Not a valid uniform cast.
      return matchFailure();
    }

    Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
    if (!dequantizedValue) {
      return matchFailure();
    }

    rewriter.replaceOp(op, dequantizedValue);
    return matchSuccess();
  }
};

} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//
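// Lowers a real add between operands that share the same signed uniform
// quantized type: because scale and zero point are identical, the sum can be
// computed directly on storage values in a widened integer type, with a
// single zero point correction, clamp, and narrowing at the end.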
static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
                                      PatternRewriter &rewriter) {
  if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
      info.rhsType != info.resultType) {
    return failure();
  }

  // Choose a byte-aligned intermediate width big enough to perform the
  // calculation without overflow.
  // TODO: This should probably be made just big enough to avoid overflow and
  // leave the downstream tooling to decide how to align that to machine
  // word sizes.
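  // For example, adding two i8 storage values needs at most 9 bits before the
  // zero point correction, so a 16 bit intermediate suffices; wider storage
  // types fall back to a 32 bit intermediate.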
  unsigned intermediateWidth =
      info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value *lhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.lhsStorageType, info.lhs)
                        .getResult();
  Value *rhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.rhsStorageType, info.rhs)
                        .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Add.
  Value *resultValue =
      rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Zero point offset adjustment.
  // result = (lhs - zp) + (rhs - zp) + zp
  // zpOffset = -zp
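  // Derivation: with shared scale s and zero point zp, the real sum is
  //   s * (lhs - zp) + s * (rhs - zp) = s * ((lhs + rhs - zp) - zp),
  // so the quantized result is lhs + rhs - zp: a raw integer add followed by
  // a single -zp correction.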
  int zpOffset = -1 * info.resultType.getZeroPoint();
  if (zpOffset != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType, zpOffset));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);
  return success();
}
//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//
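// Lowers a real mul of signed uniform quantized operands: each operand is
// widened and corrected by its zero point, the product is computed in 32 bit
// integers, then rescaled by lhsScale * rhsScale / resultScale using a
// fixed-point multiply and rounding shift before clamping and narrowing.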
static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
                            PatternRewriter &rewriter) {
  if (!info.resultType.isSigned()) {
    return failure();
  }

  double outputMultiplierReal = info.lhsType.getScale() *
                                info.rhsType.getScale() /
                                info.resultType.getScale();
  if (outputMultiplierReal > 1.0) {
    info.op->emitWarning(
        "unimplemented: cannot multiply with multiplier > 1.0");
    return failure();
  }

  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
  // avoid overflow.
  unsigned intermediateWidth = 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value *lhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.lhsStorageType, info.lhs)
                        .getResult();
  Value *rhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.rhsStorageType, info.rhs)
                        .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Apply argument zeroPoints.
  if (info.lhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.lhsType.getZeroPoint()));
    lhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
  }

  if (info.rhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.rhsType.getZeroPoint()));
    rhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
  }

  // Mul.
  Value *resultValue =
      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Scale output.
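  // The combined real multiplier lhsScale * rhsScale / resultScale is <= 1.0
  // here, so it decomposes as an integral fixed-point multiplier times a
  // non-positive power-of-two exponent; the multiplier is applied with a
  // saturating rounding doubling high mul and the exponent with a rounding
  // right shift.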
  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
  resultValue = rewriter.create<RoundingDivideByPotISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));

  // Zero point offset adjustment.
  if (info.resultType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType,
                                     info.resultType.getZeroPoint()));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);
  return success();
}
namespace {

struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
  using OpRewritePattern<RealAddEwOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(RealAddEwOp op,
                                     PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());
    if (!info.isValid()) {
      return matchFailure();
    }

    // Try all of the permutations we support.
    if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
      return matchSuccess();
    }

    return matchFailure();
  }
};

struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
  using OpRewritePattern<RealMulEwOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(RealMulEwOp op,
                                     PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());
    if (!info.isValid()) {
      return matchFailure();
    }

    // Try all of the permutations we support.
    if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
      return matchSuccess();
    }

    return matchFailure();
  }
};

} // end anonymous namespace
//===----------------------------------------------------------------------===//
// LowerUniformRealMath pass
//===----------------------------------------------------------------------===//
void LowerUniformRealMathPass::runOnFunction() {
  auto &fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
  patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
  applyPatternsGreedily(fn, std::move(patterns));
}

FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
  return new LowerUniformRealMathPass();
}

static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
    "fxpmath-lower-uniform-real-math",
    "Lowers uniform-quantized real math ops to integer arithmetic.");
//===----------------------------------------------------------------------===//
// LowerUniformCasts pass
//===----------------------------------------------------------------------===//
void LowerUniformCastsPass::runOnFunction() {
  auto &fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
  applyPatternsGreedily(fn, std::move(patterns));
}

FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
  return new LowerUniformCastsPass();
}

static PassRegistration<LowerUniformCastsPass>
    lowerUniformCastsPass("fxpmath-lower-uniform-casts",
                          "Lowers uniform-quantized casts.");