Files
clang-p2996/mlir/lib/Dialect/Index/IR/IndexOps.cpp
2022-12-07 20:17:39 -08:00

463 lines
16 KiB
C++

//===- IndexOps.cpp - Index operation definitions --------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SmallString.h"
using namespace mlir;
using namespace mlir::index;
//===----------------------------------------------------------------------===//
// IndexDialect
//===----------------------------------------------------------------------===//
void IndexDialect::registerOperations() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
>();
}
Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
Type type, Location loc) {
// Materialize bool constants as `i1`.
if (auto boolValue = dyn_cast<BoolAttr>(value)) {
if (!type.isSignlessInteger(1))
return nullptr;
return b.create<BoolConstantOp>(loc, type, boolValue);
}
// Materialize integer attributes as `index`.
if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
if (!indexValue.getType().isa<IndexType>() || !type.isa<IndexType>())
return nullptr;
assert(indexValue.getValue().getBitWidth() ==
IndexType::kInternalStorageBitWidth);
return b.create<ConstantOp>(loc, indexValue);
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// Fold Utilities
//===----------------------------------------------------------------------===//
/// Fold an index operation irrespective of the target bitwidth. The
/// operation must satisfy the property:
///
/// ```
/// trunc(f(a, b)) = f(trunc(a), trunc(b))
/// ```
///
/// For all values of `a` and `b`. The function accepts a lambda that computes
/// the integer result, which in turn must satisfy the above property.
static OpFoldResult foldBinaryOpUnchecked(
ArrayRef<Attribute> operands,
function_ref<Optional<APInt>(const APInt &, const APInt &)> calculate) {
assert(operands.size() == 2 && "binary operation expected 2 operands");
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
if (!lhs || !rhs)
return {};
Optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
if (!result)
return {};
assert(result->trunc(32) ==
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
return IntegerAttr::get(IndexType::get(lhs.getContext()), *result);
}
/// Fold an index operation only if the truncated 64-bit result matches the
/// 32-bit result for operations that don't satisfy the above property. These
/// are operations where the upper bits of the operands can affect the lower
/// bits of the results.
///
/// The function accepts a lambda that computes the integer result in both
/// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
/// not folded.
static OpFoldResult foldBinaryOpChecked(
ArrayRef<Attribute> operands,
function_ref<Optional<APInt>(const APInt &, const APInt &lhs)> calculate) {
assert(operands.size() == 2 && "binary operation expected 2 operands");
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
// Only fold index operands.
if (!lhs || !rhs)
return {};
// Compute the 64-bit result and the 32-bit result.
Optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
if (!result64)
return {};
Optional<APInt> result32 =
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
if (!result32)
return {};
// Compare the truncated 64-bit result to the 32-bit result.
if (result64->trunc(32) != *result32)
return {};
// The operation can be folded for these particular operands.
return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
}
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
}
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
}
//===----------------------------------------------------------------------===//
// DivSOp
//===----------------------------------------------------------------------===//
OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold division by zero.
if (rhs.isZero())
return std::nullopt;
return lhs.sdiv(rhs);
});
}
//===----------------------------------------------------------------------===//
// DivUOp
//===----------------------------------------------------------------------===//
OpFoldResult DivUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold division by zero.
if (rhs.isZero())
return std::nullopt;
return lhs.udiv(rhs);
});
}
//===----------------------------------------------------------------------===//
// CeilDivSOp
//===----------------------------------------------------------------------===//
/// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
static Optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
// Don't fold division by zero.
if (m.isZero())
return std::nullopt;
// Short-circuit the zero case.
if (n.isZero())
return n;
bool mGtZ = m.sgt(0);
if (n.sgt(0) != mGtZ) {
// If the operands have different signs, compute the negative result. Signed
// division overflow is not possible, since if `m == -1`, `n` can be at most
// `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
return -(-n).sdiv(m);
}
// Otherwise, compute the positive result. Signed division overflow is not
// possible since if `m == -1`, `x` will be `1`.
int64_t x = mGtZ ? -1 : 1;
return (n + x).sdiv(m) + 1;
}
OpFoldResult CeilDivSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, calculateCeilDivS);
}
//===----------------------------------------------------------------------===//
// CeilDivUOp
//===----------------------------------------------------------------------===//
OpFoldResult CeilDivUOp::fold(ArrayRef<Attribute> operands) {
// Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
return foldBinaryOpChecked(
operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
// Don't fold division by zero.
if (m.isZero())
return std::nullopt;
// Short-circuit the zero case.
if (n.isZero())
return n;
return (n - 1).udiv(m) + 1;
});
}
//===----------------------------------------------------------------------===//
// FloorDivSOp
//===----------------------------------------------------------------------===//
/// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
static Optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
// Don't fold division by zero.
if (m.isZero())
return std::nullopt;
// Short-circuit the zero case.
if (n.isZero())
return n;
bool mLtZ = m.slt(0);
if (n.slt(0) == mLtZ) {
// If the operands have the same sign, compute the positive result.
return n.sdiv(m);
}
// If the operands have different signs, compute the negative result. Signed
// division overflow is not possible since if `m == -1`, `x` will be 1 and
// `n` can be at most `INT_MAX`.
int64_t x = mLtZ ? 1 : -1;
return -1 - (x - n).sdiv(m);
}
OpFoldResult FloorDivSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, calculateFloorDivS);
}
//===----------------------------------------------------------------------===//
// RemSOp
//===----------------------------------------------------------------------===//
OpFoldResult RemSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.srem(rhs);
});
}
//===----------------------------------------------------------------------===//
// RemUOp
//===----------------------------------------------------------------------===//
OpFoldResult RemUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.urem(rhs);
});
}
//===----------------------------------------------------------------------===//
// MaxSOp
//===----------------------------------------------------------------------===//
OpFoldResult MaxSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.sgt(rhs) ? lhs : rhs;
});
}
//===----------------------------------------------------------------------===//
// MaxUOp
//===----------------------------------------------------------------------===//
OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.ugt(rhs) ? lhs : rhs;
});
}
//===----------------------------------------------------------------------===//
// ShlOp
//===----------------------------------------------------------------------===//
OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// We cannot fold if the RHS is greater than or equal to 32 because
// this would be UB in 32-bit systems but not on 64-bit systems. RHS is
// already treated as unsigned.
if (rhs.uge(32))
return {};
return lhs << rhs;
});
}
//===----------------------------------------------------------------------===//
// ShrSOp
//===----------------------------------------------------------------------===//
OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold if RHS is greater than or equal to 32.
if (rhs.uge(32))
return {};
return lhs.ashr(rhs);
});
}
//===----------------------------------------------------------------------===//
// ShrUOp
//===----------------------------------------------------------------------===//
OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold if RHS is greater than or equal to 32.
if (rhs.uge(32))
return {};
return lhs.lshr(rhs);
});
}
//===----------------------------------------------------------------------===//
// AndOp
//===----------------------------------------------------------------------===//
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
}
//===----------------------------------------------------------------------===//
// OrOp
//===----------------------------------------------------------------------===//
OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
}
//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//
OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
}
//===----------------------------------------------------------------------===//
// CastSOp
//===----------------------------------------------------------------------===//
bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>();
}
//===----------------------------------------------------------------------===//
// CastUOp
//===----------------------------------------------------------------------===//
bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>();
}
//===----------------------------------------------------------------------===//
// CmpOp
//===----------------------------------------------------------------------===//
/// Compare two integers according to the comparison predicate.
bool compareIndices(const APInt &lhs, const APInt &rhs,
IndexCmpPredicate pred) {
switch (pred) {
case IndexCmpPredicate::EQ:
return lhs.eq(rhs);
case IndexCmpPredicate::NE:
return lhs.ne(rhs);
case IndexCmpPredicate::SGE:
return lhs.sge(rhs);
case IndexCmpPredicate::SGT:
return lhs.sgt(rhs);
case IndexCmpPredicate::SLE:
return lhs.sle(rhs);
case IndexCmpPredicate::SLT:
return lhs.slt(rhs);
case IndexCmpPredicate::UGE:
return lhs.uge(rhs);
case IndexCmpPredicate::UGT:
return lhs.ugt(rhs);
case IndexCmpPredicate::ULE:
return lhs.ule(rhs);
case IndexCmpPredicate::ULT:
return lhs.ult(rhs);
}
llvm_unreachable("unhandled IndexCmpPredicate predicate");
}
OpFoldResult CmpOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "compare expected 2 operands");
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
if (!lhs || !rhs)
return {};
// Perform the comparison in 64-bit and 32-bit.
bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
bool result32 = compareIndices(lhs.getValue().trunc(32),
rhs.getValue().trunc(32), getPred());
if (result64 != result32)
return {};
return BoolAttr::get(getContext(), result64);
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
void ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
specialName << "idx" << getValueAttr().getValue();
setNameFn(getResult(), specialName.str());
}
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr();
}
void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
build(b, state, b.getIndexType(), b.getIndexAttr(value));
}
//===----------------------------------------------------------------------===//
// BoolConstantOp
//===----------------------------------------------------------------------===//
OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr();
}
void BoolConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), getValue() ? "true" : "false");
}
//===----------------------------------------------------------------------===//
// ODS-Generated Definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"