Files
clang-p2996/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
Hongren Zheng 4425dfba6a [mlir][polynomial] Add and verify constraints of coefficientModulus for ringAttr (#111016)
Currently the semantics of coefficientModulus are unclear, and lowering it
faces uncertainty; see, for example,
https://github.com/google/heir/pull/995#issuecomment-2387394895

Also, it lacks a verifier which should conform to the definition in the
document.

This PR further defines the semantics of coefficientModulus and adds a
verifier for it.

Cc @j2kun for review and suggestions.
2024-10-05 08:42:59 -07:00

237 lines
7.8 KiB
C++

//===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
namespace mlir {
namespace polynomial {
/// Prints the attribute as the wrapped integer polynomial enclosed in angle
/// brackets.
void IntPolynomialAttr::print(AsmPrinter &p) const {
  p << '<';
  p << getPolynomial();
  p << '>';
}
/// Prints the attribute as the wrapped floating-point polynomial enclosed in
/// angle brackets.
void FloatPolynomialAttr::print(AsmPrinter &p) const {
  const auto &polynomial = getPolynomial();
  p << '<' << polynomial << '>';
}
/// A callable that parses the coefficient using the appropriate method for the
/// given monomial type, and stores the parsed coefficient value on the
/// monomial.
///
/// The returned OptionalParseResult is inspected by parseMonomial: a result
/// that holds no value is treated there as "no coefficient was parsed".
template <typename MonomialType>
using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>;
/// Try to parse a monomial. If successful, populate the fields of the outparam
/// `monomial` with the results, and the `variable` outparam with the parsed
/// variable name.
///
/// On success the remaining outparams are set as follows:
///   - `isConstantTerm` is true when the monomial has no variable; its
///     exponent is then set to 0 and `variable` is not meaningful.
///   - `shouldParseMore` is true if the monomial is followed by a '+'.
///
/// Returns failure if neither a coefficient nor a variable is present, if
/// `parseAndStoreCoefficient` reports a failed parse, or if an exponent
/// introduced by `**` is missing or malformed.
template <typename Monomial>
ParseResult
parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
              bool &isConstantTerm, bool &shouldParseMore,
              ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
  OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial);

  isConstantTerm = false;
  shouldParseMore = false;

  // An OptionalParseResult has three states: absent, success, and failure.
  // A coefficient that was attempted but failed to parse must not be treated
  // as if it had been parsed successfully; otherwise the monomial would be
  // accepted with whatever default coefficient the callback stored.
  // NOTE(review): presumably the callback has already emitted a diagnostic in
  // this case, so none is added here -- confirm against the callbacks.
  if (parsedCoeffResult.has_value() && failed(*parsedCoeffResult))
    return failure();

  // A + indicates it's a constant term with more to go, as in `1 + x`.
  if (succeeded(parser.parseOptionalPlus())) {
    // If no coefficient was parsed, and there's a +, then it's effectively
    // parsing an empty string.
    if (!parsedCoeffResult.has_value())
      return failure();

    monomial.setExponent(APInt(apintBitWidth, 0));
    isConstantTerm = true;
    shouldParseMore = true;
    return success();
  }

  // A monomial can be a trailing constant term, as in `x + 1`.
  if (failed(parser.parseOptionalKeyword(&variable))) {
    // If neither a coefficient nor a variable was found, then it's effectively
    // parsing an empty string.
    if (!parsedCoeffResult.has_value())
      return failure();

    monomial.setExponent(APInt(apintBitWidth, 0));
    isConstantTerm = true;
    return success();
  }

  // Parse exponentiation symbol as `**`. We can't use caret because it's
  // reserved for basic block identifiers. If no star is present, it's treated
  // as a polynomial with exponent 1.
  if (succeeded(parser.parseOptionalStar())) {
    // If there's one * there must be two.
    if (failed(parser.parseStar()))
      return failure();

    // If there's a **, then the integer exponent is required.
    APInt parsedExponent(apintBitWidth, 0);
    if (failed(parser.parseInteger(parsedExponent))) {
      parser.emitError(parser.getCurrentLocation(),
                       "found invalid integer exponent");
      return failure();
    }

    monomial.setExponent(parsedExponent);
  } else {
    monomial.setExponent(APInt(apintBitWidth, 1));
  }

  // A trailing + means the caller should go on to parse another monomial.
  if (succeeded(parser.parseOptionalPlus()))
    shouldParseMore = true;

  return success();
}
/// Parses a '+'-separated sequence of monomials terminated by '>', appending
/// each parsed monomial to `monomials` and recording the name of every
/// non-constant term's variable in `variables`.
///
/// Coefficients are parsed by `parseAndStoreCoefficient`, which is forwarded
/// to parseMonomial. Emits a diagnostic and returns failure if a monomial
/// cannot be parsed, if the sequence is not terminated by '>', or if more than
/// one distinct variable name was used.
template <typename Monomial>
LogicalResult
parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
                    llvm::StringSet<> &variables,
                    ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
  bool expectAnotherTerm = false;
  do {
    Monomial term;
    llvm::StringRef termVariable;
    bool termIsConstant = false;
    if (failed(parseMonomial<Monomial>(parser, term, termVariable,
                                       termIsConstant, expectAnotherTerm,
                                       parseAndStoreCoefficient))) {
      parser.emitError(parser.getCurrentLocation(), "expected a monomial");
      return failure();
    }

    // Constant terms carry no variable name, so only record real variables.
    if (!termIsConstant)
      variables.insert(termVariable);
    monomials.push_back(term);

    // Without a trailing '+', the only valid continuation is the closing '>'.
    if (!expectAnotherTerm && failed(parser.parseOptionalGreater())) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expected + and more monomials, or > to end polynomial attribute");
      return failure();
    }
  } while (expectAnotherTerm);

  if (variables.size() <= 1)
    return success();

  std::string vars = llvm::join(variables.keys(), ", ");
  parser.emitError(
      parser.getCurrentLocation(),
      "polynomials must have one indeterminate, but there were multiple: " +
          vars);
  return failure();
}
/// Parses an integer polynomial attribute: a '<', followed by the monomial
/// list handled by parsePolynomialAttr (which also consumes the closing '>').
///
/// Returns a null Attribute if parsing fails or if
/// IntPolynomial::fromMonomials rejects the parsed monomials.
Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  // The coefficient starts at 1 and is stored on the monomial whether or not
  // an integer was found. NOTE(review): presumably parseOptionalInteger leaves
  // it untouched when no integer is present, so `x` gets coefficient 1 --
  // confirm against the AsmParser implementation.
  ParseCoefficientFn<IntMonomial> parseIntCoefficient =
      [&parser](IntMonomial &monomial) -> OptionalParseResult {
    APInt coefficient(apintBitWidth, 1);
    OptionalParseResult status = parser.parseOptionalInteger(coefficient);
    monomial.setCoefficient(coefficient);
    return status;
  };

  llvm::SmallVector<IntMonomial> terms;
  llvm::StringSet<> indeterminates;
  if (failed(parsePolynomialAttr<IntMonomial>(parser, terms, indeterminates,
                                              parseIntCoefficient)))
    return {};

  auto polynomial = IntPolynomial::fromMonomials(terms);
  if (succeeded(polynomial))
    return IntPolynomialAttr::get(parser.getContext(), polynomial.value());

  parser.emitError(parser.getCurrentLocation())
      << "parsed polynomial must have unique exponents among monomials";
  return {};
}
/// Parses a floating-point polynomial attribute: a '<', followed by the
/// monomial list handled by parsePolynomialAttr (which also consumes the
/// closing '>').
///
/// Returns a null Attribute if parsing fails or if
/// FloatPolynomial::fromMonomials rejects the parsed monomials.
Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  llvm::SmallVector<FloatMonomial> terms;
  llvm::StringSet<> indeterminates;
  const LogicalResult parsed = parsePolynomialAttr<FloatMonomial>(
      parser, terms, indeterminates,
      [&parser](FloatMonomial &monomial) -> OptionalParseResult {
        // parseFloat yields a plain ParseResult, so the wrapped
        // OptionalParseResult returned here always holds a value.
        // NOTE(review): unlike the integer variant this is not an "optional"
        // parse; check how monomials written without a coefficient behave.
        double coefficient = 1.0;
        ParseResult status = parser.parseFloat(coefficient);
        monomial.setCoefficient(APFloat(coefficient));
        return OptionalParseResult(status);
      });
  if (failed(parsed))
    return {};

  auto polynomial = FloatPolynomial::fromMonomials(terms);
  if (succeeded(polynomial))
    return FloatPolynomialAttr::get(parser.getContext(), polynomial.value());

  parser.emitError(parser.getCurrentLocation())
      << "parsed polynomial must have unique exponents among monomials";
  return {};
}
/// Verifies the parameters of a ring attribute.
///
/// Only `coefficientModulus` is constrained here. When it is present:
///   - `coefficientType` must be an IntegerType,
///   - the modulus must be neither zero nor negative, and
///   - the number of active bits of (modulus - 1) must not exceed the width
///     of `coefficientType`.
/// `polynomialModulus` is not inspected. Any violation is reported through
/// `emitError` and yields failure.
LogicalResult
RingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
                 Type coefficientType, IntegerAttr coefficientModulus,
                 IntPolynomialAttr polynomialModulus) {
  // Without a coefficient modulus there is nothing to check.
  if (!coefficientModulus)
    return success();

  auto integerCoeffType = llvm::dyn_cast<IntegerType>(coefficientType);
  if (!integerCoeffType)
    return emitError()
           << "coefficientModulus specified but coefficientType is not "
              "integral";

  const APInt modulus = coefficientModulus.getValue();
  if (modulus.isZero())
    return emitError() << "coefficientModulus should not be 0";
  if (modulus.isNegative())
    return emitError() << "coefficientModulus should be positive";

  // The width check is made against modulus - 1 rather than the modulus
  // itself.
  const unsigned requiredBits = (modulus - 1).getActiveBits();
  const unsigned availableBits = integerCoeffType.getWidth();
  if (requiredBits > availableBits)
    return emitError() << "coefficientModulus needs bit width of "
                       << requiredBits
                       << " but coefficientType can only contain "
                       << availableBits << " bits";

  return success();
}
} // namespace polynomial
} // namespace mlir