clang-p2996/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
Sandeep Dasgupta 81d7eef134 Sub-channel quantized type implementation (#120172)
This is an implementation for [RFC: Supporting Sub-Channel Quantization
in
MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694).

In order to make the review process easier, the PR has been divided into
the following commit labels:

1. **Add implementation for sub-channel type:** Includes the class
design for `UniformQuantizedSubChannelType`, printer/parser and bytecode
read/write support. The existing types (per-tensor and per-axis) are
unaltered.
2. **Add implementation for sub-channel type:** Lowering of
`quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python Apis:** We first define the C APIs and build the
Python APIs on top of those.
4. **Add pass to normalize generic ....:** This pass normalizes
sub-channel quantized types to per-tensor or per-axis types, if possible.
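
To make the lowering in item 2 more concrete, here is a minimal sketch of sub-channel dequantization for a row-major 2-D tensor. It assumes that the element at `(i, j)` uses the scale and zero point of the block at `(i / blockRows, j / blockCols)`. The function `dequantizeSubChannel` and its parameters are hypothetical names for illustration only; this is not the code generated by the `quant.dcast` lowering.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of MLIR: dequantizes a row-major
// rows x cols tensor whose scales and zero points are stored per block.
std::vector<float> dequantizeSubChannel(const std::vector<int8_t> &data,
                                        size_t rows, size_t cols,
                                        size_t blockRows, size_t blockCols,
                                        const std::vector<float> &scales,
                                        const std::vector<int8_t> &zeroPoints) {
  // Assumption: block sizes evenly divide the corresponding dimensions.
  assert(blockRows > 0 && blockCols > 0);
  assert(rows % blockRows == 0 && cols % blockCols == 0);
  const size_t scaleCols = cols / blockCols;
  assert(data.size() == rows * cols);
  assert(scales.size() == (rows / blockRows) * scaleCols);
  assert(zeroPoints.size() == scales.size());

  std::vector<float> result(data.size());
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      // Every element inside one block shares a single scale/zero-point pair.
      const size_t block = (i / blockRows) * scaleCols + (j / blockCols);
      const int shifted = static_cast<int>(data[i * cols + j]) -
                          static_cast<int>(zeroPoints[block]);
      result[i * cols + j] = scales[block] * static_cast<float>(shifted);
    }
  }
  return result;
}
```

With a 2x4 tensor and 2x2 blocks, the scales form a 1x2 grid: columns 0 and 1 share the first scale, and columns 2 and 3 share the second.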


A design note:
- **Explicitly storing the `quantized_dimensions`, even when they can be
derived for ranked tensors.**
While it is possible to infer the quantized dimensions from the static shape
of the scales (or zero-points) tensor for ranked data tensors
([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3)
for background), there are cases where this can lead to ambiguity and
issues with round-tripping.

Consider the example:

```
tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>
```

The shape of the scales tensor is [1, 2], which might suggest that only
axis 1 is quantized. While this inference is technically correct, as the
block size for axis 0 is a degenerate case (equal to the dimension
size), it can cause problems with round-tripping. Therefore, even for
ranked tensors, we are explicitly storing the quantized dimensions.
Suggestions welcome!
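
The round-tripping concern can be sketched as follows. The sketch assumes the written type assigns block size 2 to axis 0 (the degenerate case) and block size 2 to axis 1, which matches the [1, 2] scales shape for a 2x4 tensor. `inferQuantizedDimensions` and `AxisBlock` are hypothetical names for illustration, not part of the Quant dialect.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical representation of one quantized dimension: (axis, block size).
using AxisBlock = std::pair<int32_t, int64_t>;

// Hypothetical inference: recover the quantized dimensions purely from the
// data shape and the scales shape. An axis whose scales extent is 1 looks
// unquantized, so it is dropped.
std::vector<AxisBlock>
inferQuantizedDimensions(const std::vector<int64_t> &dataShape,
                         const std::vector<int64_t> &scalesShape) {
  assert(dataShape.size() == scalesShape.size());
  std::vector<AxisBlock> result;
  for (size_t axis = 0; axis < dataShape.size(); ++axis) {
    assert(scalesShape[axis] > 0 && dataShape[axis] % scalesShape[axis] == 0);
    if (scalesShape[axis] == 1)
      continue; // A single block spans the whole axis.
    result.emplace_back(static_cast<int32_t>(axis),
                        dataShape[axis] / scalesShape[axis]);
  }
  return result;
}
```

For the 2x4 data shape and the [1, 2] scales shape, this inference yields only axis 1 with block size 2. The degenerate entry for axis 0 is lost, so printing the inferred dimensions would not reproduce the type as it was written. Storing the dimensions explicitly avoids that mismatch.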


PS: I understand that the upcoming holidays may impact your schedule, so
please take your time with the review. There's no rush.
2025-03-23 07:37:55 -05:00

//===- QuantDialectBytecode.cpp - Quant Bytecode Implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "QuantDialectBytecode.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::quant;
namespace {
static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader,
                                       double &val) {
  auto valOr =
      reader.readAPFloatWithKnownSemantics(llvm::APFloat::IEEEdouble());
  if (failed(valOr))
    return failure();
  val = valOr->convertToDouble();
  return success();
}
#include "mlir/Dialect/Quant/IR/QuantDialectBytecode.cpp.inc"
/// This class implements the bytecode interface for the Quant dialect.
struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
  QuantDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }
};
} // namespace
void quant::detail::addBytecodeInterface(QuantDialect *dialect) {
  dialect->addInterfaces<QuantDialectBytecodeInterface>();
}