Full revamp of the 'quant' dialect. This is an implementation for the RFC at https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942
215 lines
7.6 KiB
C++
215 lines
7.6 KiB
C++
//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "QuantDialectBytecode.h"
|
|
#include "TypeDetail.h"
|
|
|
|
#include "mlir/Dialect/Quant/IR/Quant.h"
|
|
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
|
|
#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
|
|
|
|
|
|
namespace mlir {
|
|
namespace quant {
|
|
|
|
namespace {
|
|
|
|
// Verify the integrity of per-axis quantization information, if present.
|
|
//
|
|
// - quantizedType
|
|
// Any quantized type. Any quantized type with no per-axis quantization is
|
|
// ignored.
|
|
//
|
|
// - containerType
|
|
// Original input or result type of the operation using the provided quantized
|
|
// type. Used to ensure that the quantized type appears within a tensor and
|
|
// that the tensor is compatible with per-axis quantization information.
|
|
//
|
|
LogicalResult verifyPerAxisQuantization(Operation *op,
|
|
QuantizedType quantizedType,
|
|
Type containerType) {
|
|
auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
|
|
if (!quantizedPerAxisType)
|
|
return success();
|
|
|
|
auto tensorType = dyn_cast<TensorType>(containerType);
|
|
if (!tensorType)
|
|
return op->emitError("scalar types may not use per-axis quantization");
|
|
|
|
if (!tensorType.hasRank())
|
|
return success();
|
|
|
|
int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
|
|
if (quantizedDimension >= tensorType.getRank())
|
|
return op->emitError("quantized dimension must be less than tensor rank");
|
|
|
|
int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
|
|
if (quantizedDimensionSize != ShapedType::kDynamic &&
|
|
quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
|
|
return op->emitError(
|
|
"quantized dimension size does not match number of scales");
|
|
|
|
return success();
|
|
}
|
|
|
|
// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
|
|
//
|
|
// - quantizedType
|
|
// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
|
|
// whether as a primitive type or in a tensor.
|
|
//
|
|
// - floatType
|
|
// Float type used in the input ('quant.qcast') or result ('quant.dcast'),
|
|
// whether as a primitive type or in a tensor.
|
|
//
|
|
// - containerType
|
|
// Type of original input or result.
|
|
//
|
|
LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
|
|
FloatType floatType, Type containerType) {
|
|
if (quantizedType.getExpressedType() != floatType)
|
|
return op->emitError(
|
|
"expressed type in quantized type expected to match float type");
|
|
|
|
// Veriy integrity of per-axis quantization information, if present.
|
|
return verifyPerAxisQuantization(op, quantizedType, containerType);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dialect
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void QuantDialect::initialize() {
|
|
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
|
|
UniformQuantizedPerAxisType>();
|
|
addOperations<
|
|
#define GET_OP_LIST
|
|
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
|
|
>();
|
|
detail::addBytecodeInterface(this);
|
|
}
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DequantizeCastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult DequantizeCastOp::verify() {
|
|
return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
|
|
getInput().getType());
|
|
}
|
|
|
|
OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
|
|
// Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
|
|
// with the value of x. Values x and y are guaranteed to be of the same type
|
|
// in this pattern.
|
|
auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
|
|
if (!srcQcastOp)
|
|
return {};
|
|
assert(srcQcastOp.getInput().getType() == getType());
|
|
return srcQcastOp.getInput();
|
|
}
|
|
|
|
FloatType DequantizeCastOp::getFloatType() {
|
|
return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
|
|
}
|
|
|
|
QuantizedType DequantizeCastOp::getQuantizedType() {
|
|
return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
|
|
}
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// QuantizeCastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult QuantizeCastOp::verify() {
|
|
return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
|
|
getInput().getType());
|
|
}
|
|
|
|
OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
|
|
// Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
|
|
// with the value of x if the casts invert each other. Contrary to the folding
|
|
// pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
|
|
// x and y are not guaranteed to be of the same type here, as they may use
|
|
// different quantization parameters.
|
|
auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
|
|
if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
|
|
return {};
|
|
return srcDcastOp.getInput();
|
|
}
|
|
|
|
FloatType QuantizeCastOp::getFloatType() {
|
|
return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
|
|
}
|
|
|
|
QuantizedType QuantizeCastOp::getQuantizedType() {
|
|
return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
|
|
}
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// StorageCastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult StorageCastOp::verify() {
|
|
auto quantizedType = getQuantizedType();
|
|
auto integerType = getIntegerType();
|
|
if (quantizedType.getStorageType() != integerType)
|
|
return emitError(
|
|
"storage type in quantized type expected to match integer type");
|
|
|
|
// Verify integrity of per-axis quantization information, if available. While
|
|
// the quantization type may appear in the input or the result, their tensor
|
|
// shapes are guaranteed to be identical at this point.
|
|
return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
|
|
}
|
|
|
|
OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
|
|
// Matches x -> quant.scast -> quant.scast -> y, replacing the second
|
|
// quant.scast with the value of x if the casts invert each other.
|
|
auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
|
|
if (!srcScastOp || srcScastOp.getInput().getType() != getType())
|
|
return {};
|
|
return srcScastOp.getInput();
|
|
}
|
|
|
|
IntegerType StorageCastOp::getIntegerType() {
|
|
auto inputScalarType = getElementTypeOrSelf(getInput().getType());
|
|
if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
|
|
return integerType;
|
|
|
|
auto resultScalarType = getElementTypeOrSelf(getResult().getType());
|
|
return cast<IntegerType>(resultScalarType);
|
|
}
|
|
|
|
QuantizedType StorageCastOp::getQuantizedType() {
|
|
auto inputScalarType = getElementTypeOrSelf(getInput().getType());
|
|
if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
|
|
return quantizedType;
|
|
|
|
auto resultScalarType = getElementTypeOrSelf(getResult().getType());
|
|
return cast<QuantizedType>(resultScalarType);
|
|
}
|
|
|
|
|
|
} // namespace quant
|
|
} // namespace mlir
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
|
|
|