Files
clang-p2996/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
Hongren Zheng d48777ece5 [mlir][polynomial] remove incorrect canonicalization rule (#110318)
arith.add for tensor does not mod coefficientModulus, and it may
overflow; the result could be incorrect

It should be rewritten as modular arithmetic instead of arith

Revert https://github.com/llvm/llvm-project/pull/93132
Addresses https://github.com/google/heir/issues/749

Cc @j2kun
2024-09-28 08:37:57 -07:00

299 lines
11 KiB
C++

//===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APInt.h"
using namespace mlir;
using namespace mlir::polynomial;
/// Convenience builder for from_tensor: the result polynomial type is
/// inferred from the given ring attribute, and the tensor operand supplies
/// the coefficients.
void FromTensorOp::build(OpBuilder &builder, OperationState &result,
                         Value input, RingAttr ring) {
  // NOTE(review): removed dead locals (an unused coefficient-modulus APInt
  // computed from the tensor element width) whose unchecked dyn_cast of the
  // input type could dereference null for non-tensor inputs; none of that
  // affected the built op.
  Type resultType = PolynomialType::get(builder.getContext(), ring);
  build(builder, result, resultType, input);
}
/// Verifies from_tensor: when the result ring carries a polynomialModulus,
/// the input must be a rank-1 tensor with at most degree-many entries, and
/// the tensor's element bit width must fit within the ring's coefficient
/// type. Emits an op error with an attached explanatory note on failure.
LogicalResult FromTensorOp::verify() {
ArrayRef<int64_t> tensorShape = getInput().getType().getShape();
RingAttr ring = getOutput().getType().getRing();
IntPolynomialAttr polyMod = ring.getPolynomialModulus();
// A ring without a polynomialModulus imposes no shape constraint.
if (polyMod) {
unsigned polyDegree = polyMod.getPolynomial().getDegree();
// Fewer entries than the degree are accepted (note the <=); presumably the
// remaining high-order coefficients are implicitly zero — confirm against
// the op's documented semantics.
bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
if (!compatible) {
InFlightDiagnostic diag = emitOpError()
<< "input type " << getInput().getType()
<< " does not match output type "
<< getOutput().getType();
diag.attachNote()
<< "the input type must be a tensor of shape [d] where d "
"is at most the degree of the polynomialModulus of "
"the output type's ring attribute";
return diag;
}
}
// Independently of shape, every coefficient must be representable in the
// ring's coefficient type without truncation.
unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
InFlightDiagnostic diag = emitOpError()
<< "input tensor element type "
<< getInput().getType().getElementType()
<< " is too large to fit in the coefficients of "
<< getOutput().getType();
diag.attachNote() << "the input tensor's elements must be rescaled"
" to fit before using from_tensor";
return diag;
}
return success();
}
/// Verifies to_tensor: when the input ring carries a polynomialModulus, the
/// result must be a rank-1 tensor whose size is exactly the degree of that
/// modulus (one slot per coefficient).
LogicalResult ToTensorOp::verify() {
  ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
  IntPolynomialAttr polyMod =
      getInput().getType().getRing().getPolynomialModulus();
  // A ring without a polynomialModulus imposes no shape constraint.
  if (!polyMod)
    return success();

  unsigned polyDegree = polyMod.getPolynomial().getDegree();
  bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
  if (compatible)
    return success();

  InFlightDiagnostic diag = emitOpError()
                            << "input type " << getInput().getType()
                            << " does not match output type "
                            << getOutput().getType();
  // The check above requires the dimension to equal the degree exactly, so
  // the note says "exactly" (it previously, incorrectly, said "at most").
  diag.attachNote()
      << "the output type must be a tensor of shape [d] where d "
         "is exactly the degree of the polynomialModulus of "
         "the input type's ring attribute";
  return diag;
}
/// Verifies mul_scalar: the scalar operand's type must equal the coefficient
/// type of the polynomial operand's ring.
LogicalResult MulScalarOp::verify() {
  // The polynomial operand may be a bare polynomial or a shaped container of
  // polynomials; in the latter case inspect the element type.
  Type operandType = getPolynomial().getType();
  auto shapedType = dyn_cast<ShapedType>(operandType);
  PolynomialType polyType =
      shapedType ? cast<PolynomialType>(shapedType.getElementType())
                 : cast<PolynomialType>(operandType);

  Type coefficientType = polyType.getRing().getCoefficientType();
  Type scalarType = getScalar().getType();
  if (coefficientType == scalarType)
    return success();

  return emitOpError() << "polynomial coefficient type " << coefficientType
                       << " does not match scalar type " << scalarType;
}
/// Test if a value is a primitive nth root of unity modulo cmod.
///
/// `root` is a primitive nth root of unity mod `cmod` when root^n == 1
/// (mod cmod) and no smaller positive power of root equals 1. Runs in O(n)
/// modular multiplications.
bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
                               const APInt &cmod) {
  // The first or subsequent multiplications may overflow the input bit width,
  // so scale the operands up to a width where a product cannot overflow.
  unsigned requiredBitWidth =
      std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2);
  APInt r = APInt(root).zextOrTrunc(requiredBitWidth);
  APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth);
  // The condition is ule (<=), so the message must match it; it previously
  // claimed strict "less than" while permitting equality.
  assert(r.ule(cmodExt) && "root must be at most cmod");
  uint64_t upperBound = n.getZExtValue();
  // Walk r^1, r^2, ..., r^(n-1) mod cmod; seeing 1 early means the
  // multiplicative order of root divides some k < n, so it is not primitive.
  APInt a = r;
  for (size_t k = 1; k < upperBound; k++) {
    if (a.isOne())
      return false;
    a = (a * r).urem(cmodExt);
  }
  // Here a == r^n mod cmod; primitivity additionally requires r^n == 1.
  return a.isOne();
}
/// Verify that the types involved in an NTT or INTT operation are
/// compatible: the tensor side must carry a ring encoding equal to the
/// polynomial side's ring, have shape [degree of the polynomialModulus],
/// and any supplied primitive root must actually be a primitive root of
/// unity of the stated degree modulo the coefficient modulus.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
RankedTensorType tensorType,
std::optional<PrimitiveRootAttr> root) {
// The tensor must be tagged with the ring it was transformed under.
Attribute encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
<< "expects a ring encoding to be provided to the tensor";
}
auto encodedRing = dyn_cast<RingAttr>(encoding);
if (!encodedRing) {
return op->emitOpError()
<< "the provided tensor encoding is not a ring attribute";
}
// Attribute equality here is structural equality of the two RingAttrs.
if (encodedRing != ring) {
return op->emitOpError()
<< "encoded ring type " << encodedRing
<< " is not equivalent to the polynomial ring " << ring;
}
// The evaluation-domain tensor needs exactly one slot per coefficient.
unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
ArrayRef<int64_t> tensorShape = tensorType.getShape();
bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
if (!compatible) {
InFlightDiagnostic diag = op->emitOpError()
<< "tensor type " << tensorType
<< " does not match output type " << ring;
diag.attachNote() << "the tensor must have shape [d] where d "
"is exactly the degree of the polynomialModulus of "
"the polynomial type's ring attribute";
return diag;
}
// The root attribute is optional; when present, check it is a primitive
// root of unity of the stated degree mod the coefficient modulus.
if (root.has_value()) {
APInt rootValue = root.value().getValue().getValue();
APInt rootDegree = root.value().getDegree().getValue();
APInt cmod = ring.getCoefficientModulus().getValue();
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
return op->emitOpError()
<< "provided root " << rootValue.getZExtValue()
<< " is not a primitive root "
<< "of unity mod " << cmod.getZExtValue()
<< ", with the specified degree " << rootDegree.getZExtValue();
}
}
return success();
}
/// NTT: forward transform — the ring comes from the polynomial input, the
/// tensor being checked is the output.
LogicalResult NTTOp::verify() {
  RingAttr ring = getInput().getType().getRing();
  RankedTensorType tensorType = getOutput().getType();
  return verifyNTTOp(getOperation(), ring, tensorType, getRoot());
}
/// INTT: inverse transform — mirror of NTT: the ring comes from the
/// polynomial output, the tensor being checked is the input.
LogicalResult INTTOp::verify() {
  RingAttr ring = getOutput().getType().getRing();
  RankedTensorType tensorType = getInput().getType();
  return verifyNTTOp(getOperation(), ring, tensorType, getRoot());
}
/// Parses polynomial.constant in either the short keyword-prefixed form
/// ("int"/"float" followed by a bare polynomial attribute and a colon-type)
/// or the fully verbose typed-attribute form. The branch order matters:
/// keyword forms are tried first, then the verbose int form, then the
/// verbose float form.
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
// Using the built-in parser.parseAttribute requires the full
// #polynomial.typed_int_polynomial syntax, which is excessive.
// Instead we parse a keyword int to signal it's an integer polynomial
Type type;
if (succeeded(parser.parseOptionalKeyword("float"))) {
Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr);
// If the short-form attribute fails to parse we fall through to the
// verbose forms below (the "float" keyword has already been consumed).
if (floatPolyAttr) {
if (parser.parseColon() || parser.parseType(type))
return failure();
// Wrap the bare polynomial in a typed attribute so the op carries its
// result type in the "value" attribute.
result.addAttribute("value",
TypedFloatPolynomialAttr::get(type, floatPolyAttr));
result.addTypes(type);
return success();
}
}
if (succeeded(parser.parseOptionalKeyword("int"))) {
Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr);
if (intPolyAttr) {
if (parser.parseColon() || parser.parseType(type))
return failure();
result.addAttribute("value",
TypedIntPolynomialAttr::get(type, intPolyAttr));
result.addTypes(type);
return success();
}
}
// In the worst case, still accept the verbose versions.
TypedIntPolynomialAttr typedIntPolyAttr;
OptionalParseResult res =
parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
typedIntPolyAttr, "value", result.attributes);
if (res.has_value() && succeeded(res.value())) {
result.addTypes(typedIntPolyAttr.getType());
return success();
}
// Last resort: the verbose float attribute form.
TypedFloatPolynomialAttr typedFloatPolyAttr;
res = parser.parseAttribute<TypedFloatPolynomialAttr>(
typedFloatPolyAttr, "value", result.attributes);
if (res.has_value() && succeeded(res.value())) {
result.addTypes(typedFloatPolyAttr.getType());
return success();
}
return failure();
}
/// Prints polynomial.constant in the short keyword form that parse() accepts:
/// an "int"/"float" keyword, the bare polynomial attribute, then ": type".
void ConstantOp::print(OpAsmPrinter &p) {
  p << " ";
  // Dispatch on the concrete typed-polynomial attribute kind.
  Attribute valueAttr = getValue();
  if (auto typedInt = dyn_cast<TypedIntPolynomialAttr>(valueAttr)) {
    p << "int";
    typedInt.getValue().print(p);
  } else if (auto typedFloat = dyn_cast<TypedFloatPolynomialAttr>(valueAttr)) {
    p << "float";
    typedFloat.getValue().print(p);
  } else {
    assert(false && "unexpected attribute type");
  }
  p << " : ";
  p.printType(getOutput().getType());
}
/// Infers the single result type of polynomial.constant directly from the
/// typed polynomial "value" attribute, which carries its own type.
LogicalResult ConstantOp::inferReturnTypes(
    MLIRContext *context, std::optional<mlir::Location> location,
    ConstantOp::Adaptor adaptor,
    llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
  Attribute value = adaptor.getValue();
  if (auto typedInt = dyn_cast<TypedIntPolynomialAttr>(value)) {
    inferredReturnTypes.push_back(typedInt.getType());
    return success();
  }
  if (auto typedFloat = dyn_cast<TypedFloatPolynomialAttr>(value)) {
    inferredReturnTypes.push_back(typedFloat.getType());
    return success();
  }
  assert(false && "unexpected attribute type");
  // Reached only in no-assert builds; report failure rather than inventing
  // a type.
  return failure();
}
//===----------------------------------------------------------------------===//
// TableGen'd canonicalization patterns
//===----------------------------------------------------------------------===//
namespace {
#include "PolynomialCanonicalization.inc"
} // namespace
/// Registers the TableGen-defined SubAsAdd pattern (see
/// PolynomialCanonicalization.inc) for polynomial.sub.
void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SubAsAdd>(context);
}
/// Registers the TableGen-defined NTTAfterINTT pattern (see
/// PolynomialCanonicalization.inc), which folds an NTT applied to the
/// result of an INTT.
void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<NTTAfterINTT>(context);
}
/// Registers the TableGen-defined INTTAfterNTT pattern (see
/// PolynomialCanonicalization.inc), which folds an INTT applied to the
/// result of an NTT.
void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<INTTAfterNTT>(context);
}