Files
clang-p2996/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
Tres Popp c2c83e97c3 Revert "Revert "Reorder MLIRContext location in BuiltinAttributes.h""
This reverts commit 511dd4f438 along with
a couple fixes.

Original message:
Now the context is the first, rather than the last input.

This better matches the rest of the infrastructure and makes
it easier to move these types to being declaratively specified.

Phabricator: https://reviews.llvm.org/D96111
2021-02-08 10:39:58 +01:00

86 lines
3.2 KiB
C++

//===- Interchange.cpp - Linalg interchange transformation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg interchange transformation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#define DEBUG_TYPE "linalg-interchange"
using namespace mlir;
using namespace mlir::linalg;
LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
Operation *op, ArrayRef<unsigned> interchangeVector) {
if (interchangeVector.empty())
return failure();
// Transformation applies to generic ops only.
if (!isa<GenericOp, IndexedGenericOp>(op))
return failure();
LinalgOp linOp = cast<LinalgOp>(op);
// Transformation applies to buffers only.
if (!linOp.hasBufferSemantics())
return failure();
// Permutation must be applicable.
if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size())
return failure();
// Permutation map must be invertible.
if (!inversePermutation(
AffineMap::getPermutationMap(interchangeVector, op->getContext())))
return failure();
return success();
}
LinalgOp mlir::linalg::interchange(LinalgOp op,
ArrayRef<unsigned> interchangeVector) {
if (interchangeVector.empty())
return op;
MLIRContext *context = op.getContext();
auto permutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, context));
assert(permutationMap && "expected permutation to be invertible");
SmallVector<Attribute, 4> newIndexingMaps;
auto indexingMaps = op.indexing_maps().getValue();
for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
newIndexingMaps.push_back(AffineMapAttr::get(m));
}
auto itTypes = op.iterator_types().getValue();
SmallVector<Attribute, 4> itTypesVector;
for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
itTypesVector.push_back(itTypes[i]);
applyPermutationToVector(itTypesVector, interchangeVector);
op->setAttr(getIndexingMapsAttrName(),
ArrayAttr::get(context, newIndexingMaps));
op->setAttr(getIteratorTypesAttrName(),
ArrayAttr::get(context, itTypesVector));
return op;
}