//===- TestLinalgDistribution.cpp - Test Linalg hoisting functions --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements logic for testing Linalg hoisting functions. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::linalg; template static linalg::ProcInfo getGpuBlockInfo(OpBuilder &b, Location loc) { std::string d(1, dim); StringAttr attr = b.getStringAttr(d); Type indexType = b.getIndexType(); ProcInfo procInfo = {b.create(loc, indexType, attr), b.create(loc, indexType, attr)}; return procInfo; } static LinalgLoopDistributionOptions getDistributionOptions() { LinalgLoopDistributionOptions opts; opts.procInfoMap.insert(std::make_pair("block_x", getGpuBlockInfo<'x'>)); opts.procInfoMap.insert(std::make_pair("block_y", getGpuBlockInfo<'y'>)); return opts; } namespace { struct TestLinalgDistribution : public PassWrapper { StringRef getArgument() const final { return "test-linalg-distribution"; } StringRef getDescription() const final { return "Test Linalg distribution."; } TestLinalgDistribution() = default; TestLinalgDistribution(const TestLinalgDistribution &pass) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnFunction() override; }; } // namespace void TestLinalgDistribution::runOnFunction() { auto funcOp = getFunction(); OwningRewritePatternList distributeTiledLoopsPatterns(&getContext()); populateLinalgDistributeTiledLoopPattern( distributeTiledLoopsPatterns, getDistributionOptions(), LinalgTransformationFilter( ArrayRef{}, {Identifier::get("distributed", funcOp.getContext())}) .addFilter([](Operation *op) { return success(!op->getParentOfType()); })); (void)applyPatternsAndFoldGreedily(funcOp, std::move(distributeTiledLoopsPatterns)); // Ensure we drop the marker in the end. funcOp.walk([](LinalgOp op) { op->removeAttr(LinalgTransforms::kLinalgTransformMarker); }); } namespace mlir { namespace test { void registerTestLinalgDistribution() { PassRegistration(); } } // namespace test } // namespace mlir