This patch attaches more precise side effects to the ops that currently declare memory effects: each effect instance now records which OpOperand/OpResult/BlockArgument the operation reads or writes, instead of only recording that some value is read or written. Analyses and optimizations can then query these precise side effects directly. Related discussion: https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243
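To make the intent concrete, here is a minimal consumer-side sketch of what the precise effects enable. It is illustrative only and assumes the `EffectInstance` accessors discussed in the RFC (in particular `getEffectValue<OpOperand *>()`); it is not part of the patch itself:

```cpp
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

// Collect the specific operands that `op` writes. Before this patch an
// analysis could only learn that some Value is written; with per-OpOperand
// effect instances it can attribute the write to a concrete operand.
static void collectWrittenOperands(Operation *op,
                                   SmallVectorImpl<OpOperand *> &written) {
  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!iface)
    return;
  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
  iface.getEffects(effects);
  for (const auto &effect : effects) {
    if (!isa<MemoryEffects::Write>(effect.getEffect()))
      continue;
    // getEffectValue<OpOperand *>() is the accessor proposed in the RFC; it
    // returns nullptr when the effect is not attached to an OpOperand.
    if (OpOperand *operand = effect.getEffectValue<OpOperand *>())
      written.push_back(operand);
  }
}
```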
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Value getSlice(OpBuilder &b, Location loc, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Value>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Value {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));

  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operandSegmentSizes",
                       // See generated code in
                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
                       "linalg.memoized_indexing_maps"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if tensor input of tensor.extract op is the result of a linalg.fill
    // op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get scalar input operand of linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace tensor.extract op with scalar value used to fill the tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

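/// Creates a single block inside `region`, with one scalar argument per
/// element of `inputs` and `outputs` (the element type for shaped operands,
/// the type itself otherwise), and populates the block via `bodyBuild`.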
static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

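/// Populates `effects` with the precise memory effects of `linalgOp`: a read
/// of each memref input operand, a read of each memref init operand whose
/// value is used by the payload, and a write of each memref init operand.
/// Each effect instance points at the concrete OpOperand being accessed.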
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operation (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // Check all indexing maps are identity.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands).drop_back());
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

// Retrieve the operation from the body, if it is the only one (except
// yield) and if it gets the same amount of arguments as the body does.
// If initFirst flag is enabled, we check that init takes the first position in
// operands of payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // check init
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // check rest
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}

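/// Prints the short form `{ payload-op-name attrs }` header used when a
/// map/reduce body consists of a single payload operation, eliding a fastmath
/// attribute when it is set to `none`.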
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
|
|
SmallVector<StringRef> elidedAttrs;
|
|
std::string attrToElide;
|
|
p << " { " << payloadOp->getName().getStringRef();
|
|
for (const auto &attr : payloadOp->getAttrs()) {
|
|
auto fastAttr =
|
|
llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
|
|
if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
|
|
attrToElide = attr.getName().str();
|
|
elidedAttrs.push_back(attrToElide);
|
|
break;
|
|
}
|
|
}
|
|
p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
|
|
p << " }";
|
|
}
|
|
|
|
void MapOp::print(OpAsmPrinter &p) {
|
|
Block *mapper = getBody();
|
|
Operation *payloadOp = findPayloadOp(mapper);
|
|
if (payloadOp) {
|
|
printShortForm(p, payloadOp);
|
|
}
|
|
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
p.printOptionalAttrDict((*this)->getAttrs());
|
|
|
|
if (!payloadOp) {
|
|
// Print region if the payload op was not detected.
|
|
p.increaseIndent();
|
|
p.printNewline();
|
|
p << "(";
|
|
llvm::interleaveComma(mapper->getArguments(), p,
|
|
[&](auto arg) { p.printRegionArgument(arg); });
|
|
p << ") ";
|
|
|
|
p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
|
|
p.decreaseIndent();
|
|
}
|
|
}

LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Check that the number of `inputs` matches the arity of the `mapper`
  // region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of the mapper should all match the element types of the
  // inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}

SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr MapOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
}

void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
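
// Illustrative only: a rank-2 map with two inputs and one init is fully
// parallel, and every operand gets the identity indexing map, i.e.
//   iterator types: {parallel, parallel}
//   maps: [affine_map<(d0, d1) -> (d0, d1)>,
//          affine_map<(d0, d1) -> (d0, d1)>,
//          affine_map<(d0, d1) -> (d0, d1)>]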

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}

void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}

SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
}

ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
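
// Illustrative only: reducing a rank-3 input over dimensions = [1] yields
//   input map: affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//   init map:  affine_map<(d0, d1, d2) -> (d0, d2)>
// with iterator types {parallel, reduction, parallel}.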

void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}

ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
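
// Illustrative only: both of these forms parse to the same op.
//   linalg.reduce { arith.addf }
//       ins(%in : tensor<16x32xf32>) outs(%init : tensor<16xf32>)
//       dimensions = [1]
//   linalg.reduce
//       ins(%in : tensor<16x32xf32>) outs(%init : tensor<16xf32>)
//       dimensions = [1]
//       (%in_elem: f32, %init_elem: f32) {
//         %0 = arith.addf %init_elem, %in_elem : f32
//         linalg.yield %0 : f32
//       }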

static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}

void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}

LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

static void buildIdentityRegion(OpBuilder &builder, Location loc,
                                Region &region, ValueRange inputs,
                                ValueRange outputs) {
  buildGenericRegion(builder, loc, region, inputs, outputs,
                     [](OpBuilder &b, Location loc, ValueRange args) {
                       b.create<linalg::YieldOp>(loc, args[0]);
                     });
}

void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr permutation,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getPermutationAttrName(result.name), permutation);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}

void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
        attributes);
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
          })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}

void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
}

void TransposeOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}

LogicalResult TransposeOp::verify() {
  ArrayRef<int64_t> permutationRef = getPermutation();

  if (!isPermutationVector(permutationRef))
    return emitOpError("permutation is not valid");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t rank = inputType.getRank();

  if (rank != initType.getRank())
    return emitOpError() << "input rank " << rank
                         << " does not match init rank " << initType.getRank();

  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];

    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
    }
  }

  return success();
}

SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {inversePermutation(AffineMap::getPermutationMap(
           llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
       builder.getMultiDimIdentityMap(rank)});
}
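
// Illustrative only: for permutation = [1, 0] both maps work out to
//   input: affine_map<(d0, d1) -> (d1, d0)>
//   init:  affine_map<(d0, d1) -> (d0, d1)>
// since the permutation map for [1, 0] is its own inverse.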

void TransposeOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
  // Rank-0 transpose (empty permutation) is a no-op.
  if (getPermutation().size() == 0) {
    result.push_back(getInput());
    return success();
  }
  // Identity permutation.
  if (isIdentityPermutation(getPermutation())) {
    result.push_back(getInput());
    return success();
  }

  return failure();
}

/// Fold transpose with transpose.
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
    if (!defTransposeOp)
      return failure();
    ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();
    SmallVector<int64_t> foldedPerms;
    foldedPerms.reserve(perms.size());
    for (int64_t perm : perms)
      foldedPerms.push_back(defPerms[perm]);

    rewriter.replaceOpWithNewOp<TransposeOp>(
        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
        foldedPerms);
    return success();
  }
};
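
// Illustrative only: with defPerms = [1, 2, 0] and perms = [2, 0, 1], the
// composed permutation is foldedPerms[i] = defPerms[perms[i]] = [0, 1, 2],
// so the pair
//   %t0 = linalg.transpose ins(%x) outs(%i0) permutation = [1, 2, 0]
//   %t1 = linalg.transpose ins(%t0) outs(%i1) permutation = [2, 0, 1]
// becomes a single identity transpose, which the fold above then replaces
// with its input.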

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldTransposeWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}

void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
        attributes);
}

ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
}

void BroadcastOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
}

LogicalResult BroadcastOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  }

  // Mapping from input dims to init dims.
  SmallVector<int64_t> dimMap;
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    // This dimension is mapped from the input. Init and input dims should
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }

  return success();
}

SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr BroadcastOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
       builder.getMultiDimIdentityMap(rank)});
}
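
// Illustrative only: broadcasting tensor<16xf32> into tensor<16x32xf32> with
// dimensions = [1] produces
//   input map: affine_map<(d0, d1) -> (d0)>
//   init map:  affine_map<(d0, d1) -> (d0, d1)>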

void BroadcastOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

void linalg::YieldOp::print(OpAsmPrinter &p) {
  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
}

ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
  SmallVector<Type, 2> types;
  SMLoc loc = parser.getCurrentLocation();
  return failure(parser.parseOperandList(opInfo) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
                 parser.resolveOperands(opInfo, types, loc, result.operands));
}

// Check that the operand number and types match the element types of the
// enclosing LinalgOp's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}

LogicalResult linalg::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
    return emitOpError("expected single non-empty parent region");

  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
    return verifyYield(*this, linalgOp);

  return emitOpError("expected parent op with LinalgOp interface");
}

//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//

LogicalResult IndexOp::verify() {
  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
  if (!linalgOp)
    return emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= getDim())
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
  return success();
}

/////// Operations corresponding to library calls defined with Tablegen ////////

#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
                                             unsigned rank,
                                             MLIRContext *context) {
  if (maybeMap)
    return *maybeMap;
  if (rank == 0)
    return AffineMap::get(context);
  return AffineMap::getMultiDimIdentityMap(rank, context);
}

SmallVector<AffineExpr, 4>
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                 MLIRContext *context) {
  SmallVector<AffineExpr, 4> res;
  res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
  return res;
}

SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
                                                ArrayRef<AffineExpr> b) {
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
}

static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    if (failed(appendMangledType(ss, memref.getElementType())))
      return failure();
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      else
        return failure();
    }
    return success();
  }
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    if (failed(appendMangledType(ss, vec.getElementType())))
      return failure();
    return success();
  }
  if (t.isSignlessIntOrIndexOrFloat()) {
    ss << t;
    return success();
  }
  return failure();
}

std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  std::string fun = "";
  for (NamedAttribute kv : op->getAttrs()) {
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
    }
  }
  name.reserve(128);
  std::replace(name.begin(), name.end(), '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_" << fun;
  for (Type t : op->getOperandTypes()) {
    if (failed(appendMangledType(ss, t)))
      return std::string();
    ss << "_";
  }
  std::string res = ss.str();
  res.pop_back();
  return res;
}
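
// Illustrative only: for a linalg.copy on two memref<16x32xf32> operands this
// produces something like "linalg_copy_view16x32xf32_view16x32xf32"; a type
// that cannot be mangled makes the function return the empty string.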

//===----------------------------------------------------------------------===//
// Canonicalizers and Folders.
//===----------------------------------------------------------------------===//

namespace {
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase in cases where we see memref<...x0x...>.
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
};

/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has a
/// result that is more static than the linalg op.
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // The cast can be in a conditionally reachable region, in which case
    // folding will generate invalid code. Only conservatively fold ops in the
    // same block for now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = rewriter.create<tensor::CastOp>(
        loc, resultValue.getType(), newOp->getResult(resultNumber));

    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
    return success();
  }
};
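
/// Illustrative only: given
///   %0 = linalg.generic ... -> tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// the pattern clones the generic with its out operand cast to the static
/// tensor<4x8xf32> type, then casts the new static result back to
/// tensor<?x?xf32> for the remaining users of %0.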

/// For each operand in `operands`, this function maps the static sizes of
/// dimensions to their affine dim expressions.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // a `tensor.cast` operation and the source of the cast operation has a
    // static shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}

/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types are stored in `resultTypes`. If `opOperand` requires no
/// change then `changeNeeded` is false and the same operand is added to the
/// `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If the operand is updated with a new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // The dimension has a dynamic shape and the corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}

/// Static shapes for the operands can be inferred if any one of the operands
/// has a static shape. This can be done by referring to the affine dim
/// expressions for the operand.
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Maps must be projected permutations.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      return failure();

    // Maps affine dim expressions to the static size of that dimension.
    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
    Location loc = linalgOp.getLoc();

    // For each affine dim expression, check if the size is known. If known,
    // add it to the map.
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;

    // `changeNeeded` is `false` if the operands of `linalgOp` require no
    // change in their types.
    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());

    // Iterate over all the operands and update the static sizes.
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }

    // If the generic op has all the required static information, no
    // canonicalization is needed.
    if (!changeNeeded)
      return failure();

    // Clone op.
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    SmallVector<Value> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      Type newType = newResult.getType();
      Type oldType = oldResult.getType();
      replacements.push_back(
          (newType != oldType)
              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
              : newResult);
    }
    rewriter.replaceOp(linalgOp, replacements);
    return success();
  }
};
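
/// Illustrative only: if both operands use the identity map and
///   %0 = linalg.generic ins(%a : tensor<4x8xf32>) outs(%b : tensor<?x?xf32>)
/// then d0 -> 4 and d1 -> 8 are recorded from %a, the out operand is cast to
/// tensor<4x8xf32>, and the new static result is cast back to tensor<?x?xf32>
/// to replace the old one.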

} // namespace

// All named ops canonicalizers and folders are auto-generated in the
// .cpp.inc.

//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//

LogicalResult SoftmaxOp::verify() {
  ShapedType inputType = getInputOperandType();
  ShapedType outputType = getOutputOperandType();

  ArrayRef<int64_t> inputShape = inputType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(inputShape, outputShape)))
    return emitOpError("incompatible output shape");

  int64_t inputRank = getInputOperandRank();
  int64_t dimension = getDimension();
  if ((dimension < 0) || (dimension >= inputRank))
    return emitOpError("incorrect dimension specified");

  return success();
}

SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getInputOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getInput();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}

FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getInputOperandRank();
  auto oneAttr = builder.getI64IntegerAttr(1);
  SmallVector<OpFoldResult> strides(rank, oneAttr);
  SmallVector<Value> tiledOperands;
  tiledOperands.emplace_back(
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
  tiledOperands.emplace_back(
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));

  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}

LogicalResult SoftmaxOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber == 0) {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
  return failure();
}

// cast(dynamic) -> static.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

LogicalResult
SoftmaxOp::reifyResultShapes(OpBuilder &b,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  SmallVector<OpFoldResult> shapes;
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
    if (!outputShapedType.isDynamicDim(dim)) {
      // Static dim: Return IntegerAttr.
      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
    } else {
      // Dynamic dim: Return Value.
      OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
      shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
    }
  }
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

void SoftmaxOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(),
                         &getOperation()->getOpOperand(index), /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}
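
// Illustrative only: for a memref-semantics softmax this records a Read on
// the input OpOperand and a Read plus a Write on the init OpOperand, each
// with stage 0 and effectOnFullRegion = true, so analyses can tell exactly
// which operand is read or written rather than just which values the op
// touches.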

// Helper functions for softmax decomposition.
// @{

// Helper function to produce the iterator types (reduction or parallel) and
// affine maps for the iterators used in the decomposition of softmax.
// This method creates:
// If allParallel == true:
//  - iterator type: {parallel, ..., parallel}
//  - affine maps:
//  -- identity with inputRank dimensions.
//  -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
//     where N == inputRank - 1.
//
// If allParallel == false:
//  - iterator type at dim(i) == parallel for i != \p dim and
//    dim(dim) == reduction.
//  - affine map:
//  -- identity with inputRank dimensions.
//  -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
//     where N == inputRank - 1.
static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
                                    int64_t dim, bool allParallel = false) {
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  if (!allParallel)
    iteratorTypes[dim] = utils::IteratorType::reduction;
  MLIRContext *ctxt = builder.getContext();
  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
  SmallVector<AffineExpr, 2> affineExprs;
  for (int i = 0; i < inputRank; i++) {
    if (i != dim)
      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
  }
  auto reductionMap =
      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
  return std::make_tuple(iteratorTypes, indexingMaps);
}
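
// Illustrative only: inputRank = 3, dim = 1, allParallel = false gives
//   iterator types: {parallel, reduction, parallel}
//   maps: [(d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d0, d2)]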

// Helper function to produce a linalg.generic that computes a reduction on
// dimension \p dim with the operation type \p T.
template <typename T>
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
                    int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  auto genericOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<T>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}

/// Produce a linalg generic that computes the second step of the softmax
/// decomposition: res = exp(input - max), where \p max is the max of \p input
/// on dimension \p dim.
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
                              Value max, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  // Add the affine map for the output argument.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
        Value result = b.create<math::ExpOp>(loc, diff);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}

/// Produce a linalg generic that computes the final step of the softmax
/// decomposition.
/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
///   yield n / d
/// }
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
                        Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(numerator.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
  // Add the affine map for the output tensor.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
      indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}
// @} End helper functions for softmax decomposition.

/// Given an N-dimensional tensor x, this method converts
/// softmax(x) to the following sequence of operations:
///
/// 1. Compute the max of x along dimension d. This results
///    in an (N-1)-dimensional tensor m.
///    m = max(x, dim = d)
///
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
///    an N-dimensional tensor z.
///    z = exp(x - m)
///
/// 3. Compute the sum of z along dimension d. This results in
///    an (N-1)-dimensional tensor l.
///    l = sum(z, dim = d)
///
/// 4. Divide z and l. This gives the N-dimensional softmax.
///    softmax = z / l
///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max = reduce<arith::MaximumFOp>(b, loc, input, neutralForMaxFInit,
                                        reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}
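
// Illustrative only: for x = [1.0, 2.0, 3.0] and d = 0, m = 3.0,
// z = [e^-2, e^-1, e^0], l = e^-2 + e^-1 + 1, and softmax = z / l, which
// equals exp(x) / sum(exp(x)) but avoids overflow in the exponentials.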

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}