Files
clang-p2996/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
River Riddle 1b97cdf885 [mlir][IR][NFC] Move context/location parameters of builtin Type::get methods to the start of the parameter list
This better matches the rest of the infrastructure, is much simpler, and makes it easier to move these types to being declaratively specified.

Differential Revision: https://reviews.llvm.org/D93432
2020-12-17 13:01:36 -08:00

88 lines
3.3 KiB
C++

//===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/IR/BuiltinTypes.h"
#include <numeric>
using namespace mlir;
using namespace mlir::quant;
/// Returns true when `inputType` is a FloatType, the only primitive type this
/// file treats as quantizable.
static bool isQuantizablePrimitiveType(Type inputType) {
  const bool isFloat = inputType.isa<FloatType>();
  return isFloat;
}
/// Builds a converter for `inputType`.
///
/// For tensor and vector inputs the expressed type is the element type; for
/// any other input it is the input type itself. When that candidate is not a
/// quantizable primitive, the converter is created with a null expressed type
/// to mark the conversion as unsupported.
const ExpressedToQuantizedConverter
ExpressedToQuantizedConverter::forInputType(Type inputType) {
  // Pick the type that would have to be quantizable: the element type for
  // tensors/vectors, otherwise the input type itself.
  Type candidate = inputType;
  if (inputType.isa<TensorType, VectorType>())
    candidate = inputType.cast<ShapedType>().getElementType();

  // Unsupported.
  if (!isQuantizablePrimitiveType(candidate))
    return ExpressedToQuantizedConverter{inputType, nullptr};

  // Supported: the candidate is the expressed type.
  return ExpressedToQuantizedConverter{inputType, candidate};
}
/// Produces the quantized counterpart of the converter's input type using
/// `elementalType` as the new element type.
///
/// Ranked tensors and vectors keep their shape, unranked tensors stay
/// unranked. For any other input type, `elementalType` itself is returned if
/// its expressed type equals the converter's expressed type; otherwise a null
/// type is returned. Asserts that the converter has a non-null expressed type.
Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
  assert(expressedType && "convert() on unsupported conversion");

  // Container inputs: rebuild the same kind of container around the
  // quantized element type.
  if (auto ranked = inputType.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(ranked.getShape(), elementalType);
  if (inputType.isa<UnrankedTensorType>())
    return UnrankedTensorType::get(elementalType);
  if (auto vec = inputType.dyn_cast<VectorType>())
    return VectorType::get(vec.getShape(), elementalType);

  // Non-container input: only valid when the expressed types agree.
  const bool sameExpressed = elementalType.getExpressedType() == expressedType;
  if (!sameExpressed)
    return nullptr; // Unsupported.
  return elementalType;
}
/// Quantizes `realValue` when it is a DenseFPElementsAttr by forwarding to the
/// dense overload; returns a null attribute for every other attribute kind.
ElementsAttr
UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) {
  auto denseFP = realValue.dyn_cast<DenseFPElementsAttr>();
  // TODO: handle sparse elements attributes.
  if (!denseFP)
    return nullptr;
  return convert(denseFP);
}
/// Quantizes every element of `attr`, selecting for each element the
/// converter that corresponds to its position along `quantizationDim`.
///
/// Returns a null attribute when the extent of the quantization dimension
/// does not equal the number of scales. Otherwise returns the result of
/// `attr.mapValues` with an integer element type of `storageBitWidth` bits.
DenseElementsAttr
UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
  // Creates the converter for each chunk. Normally the size of the
  // quantization dim is 3, so we can cache all the converters.
  ShapedType type = attr.getType();
  size_t dimSize = type.getDimSize(quantizationDim);
  if (dimSize != scales.size())
    return {};

  SmallVector<UniformQuantizedValueConverter, 4> converters;
  converters.reserve(dimSize);
  for (int i = 0, e = dimSize; i != e; ++i)
    converters.push_back(getPerChunkConverter(i));

  // Scan the elements of the dense elements attributes and quantize them by
  // using the right quantization parameters.
  int64_t flattenIndex = 0;
  auto shape = type.getShape();
  // Number of consecutive flattened elements that share one converter: the
  // product of all dimensions after the quantization dimension. The seed must
  // be an int64_t because std::accumulate keeps its running value in the type
  // of the seed; an `int` seed would narrow every partial product to int.
  int64_t chunkSize =
      std::accumulate(std::next(shape.begin(), quantizationDim + 1),
                      shape.end(), int64_t{1}, std::multiplies<int64_t>());
  Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth);
  // NOTE(review): this relies on mapValues invoking the callback once per
  // element in flattened order; confirm the behaviour for splat attributes.
  return attr.mapValues(newElementType, [&](const APFloat &old) {
    // Keep the quotient in 64 bits so large flattened indices are not
    // truncated before the modulo.
    const int64_t chunkIndex = (flattenIndex++) / chunkSize;
    return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
  });
}