This is an implementation of [RFC: Supporting Sub-Channel Quantization in MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694).

To make the review process easier, the PR has been divided into the following commit labels:

1. **Add implementation for sub-channel type:** Includes the class design for `UniformQuantizedSubChannelType`, printer/parser, and bytecode read/write support. The existing types (per-tensor and per-axis) are unaltered.
2. **Add implementation for sub-channel type:** Lowering of `quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python Apis:** We first define the C APIs and then build the Python APIs on top of those.
4. **Add pass to normalize generic ....:** This pass normalizes sub-channel quantized types to per-tensor or per-axis types, if possible.

A design note:

- **Explicitly storing the `quantized_dimensions`, even when they can be derived for ranked tensors.** While it's possible to infer the quantized dimensions from the static shape of the scales (or zero-points) tensor for ranked data tensors ([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3) for background), there are cases where this can lead to ambiguity and issues with round-tripping. Consider the example:

```
tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>
```

The shape of the scales tensor is [1, 2], which might suggest that only axis 1 is quantized. While this inference is technically correct, because the block size for axis 0 is a degenerate case (equal to the dimension size), it can cause problems with round-tripping: the explicitly written axis 0 entry would be lost when the type is printed back. Therefore, even for ranked tensors, we explicitly store the quantized dimensions. Suggestions welcome!

PS: I understand that the upcoming holidays may impact your schedule, so please take your time with the review. There's no rush.
71 lines
2.2 KiB
C++
71 lines
2.2 KiB
C++
//===- QuantDialectBytecode.cpp - Quant Bytecode Implementation -----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "QuantDialectBytecode.h"
|
|
#include "mlir/Bytecode/BytecodeImplementation.h"
|
|
#include "mlir/Dialect/Quant/IR/Quant.h"
|
|
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "llvm/ADT/APFloat.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::quant;
|
|
|
|
namespace {
|
|
|
|
/// Reads an APFloat with IEEE double semantics from `reader` and stores its
/// value in `val` as a native `double`.
///
/// Returns success once `val` has been written. If the reader fails to
/// produce the float, returns failure and leaves `val` unmodified.
static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader,
                                       double &val) {
  const auto &doubleSemantics = llvm::APFloat::IEEEdouble();
  auto parsed = reader.readAPFloatWithKnownSemantics(doubleSemantics);
  if (!failed(parsed)) {
    val = parsed->convertToDouble();
    return success();
  }
  return failure();
}
|
|
|
|
#include "mlir/Dialect/Quant/IR/QuantDialectBytecode.cpp.inc"
|
|
|
|
/// This class implements the bytecode interface for the Quant dialect.
|
|
struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
|
|
QuantDialectBytecodeInterface(Dialect *dialect)
|
|
: BytecodeDialectInterface(dialect) {}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Attributes
|
|
|
|
Attribute readAttribute(DialectBytecodeReader &reader) const override {
|
|
return ::readAttribute(getContext(), reader);
|
|
}
|
|
|
|
LogicalResult writeAttribute(Attribute attr,
|
|
DialectBytecodeWriter &writer) const override {
|
|
return ::writeAttribute(attr, writer);
|
|
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Types
|
|
|
|
Type readType(DialectBytecodeReader &reader) const override {
|
|
return ::readType(getContext(), reader);
|
|
}
|
|
|
|
LogicalResult writeType(Type type,
|
|
DialectBytecodeWriter &writer) const override {
|
|
return ::writeType(type, writer);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers the Quant bytecode interface on `dialect`.
///
/// `dialect` is dereferenced unconditionally, so callers must pass a valid
/// (non-null) dialect instance.
void quant::detail::addBytecodeInterface(QuantDialect *dialect) {
  QuantDialect &quantDialect = *dialect;
  quantDialect.addInterfaces<QuantDialectBytecodeInterface>();
}
|