Files
clang-p2996/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp
Nicolas Vasilache 75394e1301 [mlir][EDSC] Almost NFC - Refactor and untangle EDSC dependencies
This CL refactors EDSCs to layer them better and break unnecessary
dependencies. After this refactoring, the top-level EDSC target only
depends on IR but not on Dialects anymore and each dialect has its
own EDSC directory.

This simplifies the layering and breaks cyclic dependencies.
In particular, the declarative builder + folder are made explicit and
are now confined to Linalg.

As the refactoring occurred, certain classes and abstractions that were not
paying for themselves have been removed.

Differential Revision: https://reviews.llvm.org/D74302
2020-02-10 12:10:41 -05:00

93 lines
4.0 KiB
C++

//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;
using namespace mlir::edsc;
/// Constructs a parallel loop nest over the induction variables `ivs`.
///
/// `lbs`, `ubs` and `steps` supply one lower bound, upper bound and step per
/// entry of `ivs`; all four ranges must therefore have the same length. A
/// single loop builder, obtained from makeParallelLoopBuilder, is recorded for
/// the whole nest.
mlir::edsc::ParallelLoopNestBuilder::ParallelLoopNestBuilder(
    ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
    ArrayRef<ValueHandle> ubs, ArrayRef<ValueHandle> steps) {
  // Every induction variable needs its own lower bound, upper bound and step.
  assert(lbs.size() == ivs.size() && "Mismatch in number of arguments");
  assert(ubs.size() == ivs.size() && "Mismatch in number of arguments");
  assert(steps.size() == ivs.size() && "Mismatch in number of arguments");

  // One builder covers all the induction variables at once.
  loops.emplace_back(makeParallelLoopBuilder(ivs, lbs, ubs, steps));
}
/// Runs `fun` when one is provided, then calls operator() on every loop
/// builder recorded in the nest.
void mlir::edsc::ParallelLoopNestBuilder::operator()(
    function_ref<void(void)> fun) {
  if (fun)
    fun();

  // Walk the builders from innermost to outermost. enter() happens when a
  // LoopBuilder is constructed whereas exit() happens when its operator() is
  // called; this asymmetry is what lets imperfectly nested regions nest
  // properly (see LoopBuilder::operator()).
  for (auto &builder : llvm::reverse(loops))
    builder();
}
/// Constructs a loop nest with one loop builder per induction variable.
///
/// The i-th builder is created by makeLoopBuilder from `ivs[i]`, `lbs[i]`,
/// `ubs[i]` and `steps[i]`, in the order of `ivs`. All four ranges must have
/// the same length.
mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
                                             ArrayRef<ValueHandle> lbs,
                                             ArrayRef<ValueHandle> ubs,
                                             ArrayRef<ValueHandle> steps) {
  assert(lbs.size() == ivs.size() && "expected size of ivs and lbs to match");
  assert(ubs.size() == ivs.size() && "expected size of ivs and ubs to match");
  assert(steps.size() == ivs.size() &&
         "expected size of ivs and steps to match");

  loops.reserve(ivs.size());
  // Visit the four ranges in lockstep, creating one loop per tuple.
  for (auto entry : llvm::zip(ivs, lbs, ubs, steps)) {
    ValueHandle *iv = std::get<0>(entry);
    const ValueHandle &lb = std::get<1>(entry);
    const ValueHandle &ub = std::get<2>(entry);
    const ValueHandle &step = std::get<3>(entry);
    loops.emplace_back(makeLoopBuilder(iv, lb, ub, step));
  }
  assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
}
/// Runs `fun` when it holds a callable, then calls operator() on every loop
/// builder of the nest, in reverse order of creation.
void mlir::edsc::LoopNestBuilder::operator()(std::function<void(void)> fun) {
  if (fun)
    fun();
  // Each loop builder is invoked with an empty callback argument.
  for (auto &builder : llvm::reverse(loops))
    builder({});
}
/// Creates a `loop::ParallelOp` from the given bounds and steps and returns a
/// LoopBuilder that has entered the body block of that op.
///
/// On return, the i-th entry of `ivs` has been overwritten with a handle to
/// the i-th argument of the body block.
LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
                                                ArrayRef<ValueHandle> lbHandles,
                                                ArrayRef<ValueHandle> ubHandles,
                                                ArrayRef<ValueHandle> steps) {
  LoopBuilder result;

  // Create the op from the values wrapped by the handles.
  auto opHandle = OperationHandle::create<loop::ParallelOp>(
      SmallVector<Value, 4>(lbHandles.begin(), lbHandles.end()),
      SmallVector<Value, 4>(ubHandles.begin(), ubHandles.end()),
      SmallVector<Value, 4>(steps.begin(), steps.end()));
  loop::ParallelOp parallelOp =
      cast<loop::ParallelOp>(*opHandle.getOperation());
  auto *body = parallelOp.getBody();

  // Bind each induction variable handle to the matching body block argument.
  unsigned argIndex = 0;
  for (ValueHandle *iv : ivs)
    *iv = ValueHandle(body->getArgument(argIndex++));

  result.enter(body, /*prev=*/1);
  return result;
}
/// Creates a `loop::ForOp` with the given lower bound, upper bound and step
/// and returns a LoopBuilder that has entered the body block of that op.
///
/// On return, `*iv` has been overwritten with a handle to the induction
/// variable of the new loop.
mlir::edsc::LoopBuilder mlir::edsc::makeLoopBuilder(ValueHandle *iv,
                                                    ValueHandle lbHandle,
                                                    ValueHandle ubHandle,
                                                    ValueHandle stepHandle) {
  LoopBuilder result;
  loop::ForOp forOp =
      OperationHandle::createOp<loop::ForOp>(lbHandle, ubHandle, stepHandle);

  // Expose the induction variable of the new loop to the caller.
  Value inductionVar = forOp.getInductionVar();
  *iv = ValueHandle(inductionVar);

  // The body to enter is looked up through the owner of the induction
  // variable that was just stored in `*iv`.
  result.enter(loop::getForInductionVarOwner(iv->getValue()).getBody(),
               /*prev=*/1);
  return result;
}