//===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Interfaces/ViewLikeInterface.h" using namespace mlir; //===----------------------------------------------------------------------===// // ViewLike Interfaces //===----------------------------------------------------------------------===// /// Include the definitions of the loop-like interfaces. #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" LogicalResult mlir::verifyListOfOperandsOrIntegers( Operation *op, StringRef name, unsigned numElements, ArrayAttr attr, ValueRange values, llvm::function_ref isDynamic) { /// Check static and dynamic offsets/sizes/strides does not overflow type. if (attr.size() != numElements) return op->emitError("expected ") << numElements << " " << name << " values"; unsigned expectedNumDynamicEntries = llvm::count_if(attr.getValue(), [&](Attribute attr) { return isDynamic(attr.cast().getInt()); }); if (values.size() != expectedNumDynamicEntries) return op->emitError("expected ") << expectedNumDynamicEntries << " dynamic " << name << " values"; return success(); } LogicalResult mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { std::array maxRanks = op.getArrayAttrMaxRanks(); // Offsets can come in 2 flavors: // 1. Either single entry (when maxRanks == 1). // 2. Or as an array whose rank must match that of the mixed sizes. // So that the result type is well-formed. if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT op.getMixedOffsets().size() != op.getMixedSizes().size()) return op->emitError( "expected mixed offsets rank to match mixed sizes rank (") << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() << ") so the rank of the result type is well-formed."; // Ranks of mixed sizes and strides must always match so the result type is // well-formed. if (op.getMixedSizes().size() != op.getMixedStrides().size()) return op->emitError( "expected mixed sizes rank to match mixed strides rank (") << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() << ") so the rank of the result type is well-formed."; if (failed(verifyListOfOperandsOrIntegers( op, "offset", maxRanks[0], op.static_offsets(), op.offsets(), ShapedType::isDynamicStrideOrOffset))) return failure(); if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1], op.static_sizes(), op.sizes(), ShapedType::isDynamic))) return failure(); if (failed(verifyListOfOperandsOrIntegers( op, "stride", maxRanks[2], op.static_strides(), op.strides(), ShapedType::isDynamicStrideOrOffset))) return failure(); return success(); } void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayAttr integers, int64_t dynVal) { printer << '['; if (integers.empty()) { printer << "]"; return; } unsigned idx = 0; llvm::interleaveComma(integers, printer, [&](Attribute a) { int64_t val = a.cast().getInt(); if (val == dynVal) printer << values[idx++]; else printer << val; }); printer << ']'; } ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, ArrayAttr &integers, int64_t dynVal) { if (failed(parser.parseLSquare())) return failure(); // 0-D. if (succeeded(parser.parseOptionalRSquare())) { integers = parser.getBuilder().getArrayAttr({}); return success(); } SmallVector attrVals; while (true) { OpAsmParser::UnresolvedOperand operand; auto res = parser.parseOptionalOperand(operand); if (res.has_value() && succeeded(res.value())) { values.push_back(operand); attrVals.push_back(dynVal); } else { IntegerAttr attr; if (failed(parser.parseAttribute(attr))) return parser.emitError(parser.getNameLoc()) << "expected SSA value or integer"; attrVals.push_back(attr.getInt()); } if (succeeded(parser.parseOptionalComma())) continue; if (failed(parser.parseRSquare())) return failure(); break; } integers = parser.getBuilder().getI64ArrayAttr(attrVals); return success(); } bool mlir::detail::sameOffsetsSizesAndStrides( OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref cmp) { if (a.static_offsets().size() != b.static_offsets().size()) return false; if (a.static_sizes().size() != b.static_sizes().size()) return false; if (a.static_strides().size() != b.static_strides().size()) return false; for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) if (!cmp(std::get<0>(it), std::get<1>(it))) return false; for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) if (!cmp(std::get<0>(it), std::get<1>(it))) return false; for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) if (!cmp(std::get<0>(it), std::get<1>(it))) return false; return true; } SmallVector mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues, const int64_t dynamicValueIndicator) { SmallVector res; res.reserve(staticValues.size()); unsigned numDynamic = 0; unsigned count = static_cast(staticValues.size()); for (unsigned idx = 0; idx < count; ++idx) { APInt value = staticValues[idx].cast().getValue(); res.push_back(value.getSExtValue() == dynamicValueIndicator ? OpFoldResult{dynamicValues[numDynamic++]} : OpFoldResult{staticValues[idx]}); } return res; } SmallVector mlir::getMixedStridesOrOffsets(ArrayAttr staticValues, ValueRange dynamicValues) { return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamicStrideOrOffset); } SmallVector mlir::getMixedSizes(ArrayAttr staticValues, ValueRange dynamicValues) { return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamicSize); } std::pair> mlir::decomposeMixedValues(Builder &b, const SmallVectorImpl &mixedValues, const int64_t dynamicValueIndicator) { SmallVector staticValues; SmallVector dynamicValues; for (const auto &it : mixedValues) { if (it.is()) { staticValues.push_back(it.get().cast().getInt()); } else { staticValues.push_back(dynamicValueIndicator); dynamicValues.push_back(it.get()); } } return {b.getI64ArrayAttr(staticValues), dynamicValues}; } std::pair> mlir::decomposeMixedStridesOrOffsets( OpBuilder &b, const SmallVectorImpl &mixedValues) { return decomposeMixedValues(b, mixedValues, ShapedType::kDynamicStrideOrOffset); } std::pair> mlir::decomposeMixedSizes(OpBuilder &b, const SmallVectorImpl &mixedValues) { return decomposeMixedValues(b, mixedValues, ShapedType::kDynamicSize); }