Files
clang-p2996/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
Benjamin Maxwell c42512436b [mlir][ArmSME] Rename slice move operations to insert/extract_tile_slice (#106755)
This renames:

- `arm_sme.move_tile_slice_to_vector` to `arm_sme.extract_tile_slice`
- `arm_sme.move_vector_to_tile_slice` to `arm_sme.insert_tile_slice`

The new names are more consistent with the rest of MLIR and should be
easier to understand. The old names are (to me, at least) hard to parse
and easy to mix up when skimming through code.

Additionally, the syntax for `insert_tile_slice` has changed from:

```mlir
%4 = arm_sme.insert_tile_slice %0, %1, %2
  : vector<[16]xi8> into vector<[16]x[16]xi8>
```

To:

```mlir
%4 = arm_sme.insert_tile_slice %0, %1[%2]
  : vector<[16]xi8> into vector<[16]x[16]xi8>
```

This is for consistency with `extract_tile_slice`, but also helps with
readability as it makes it clear which operand is the index.
2024-09-02 11:12:40 +01:00

126 lines
4.7 KiB
C++

//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
#define DEBUG_TYPE "arith-to-arm-sme"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Conversion helpers
//===----------------------------------------------------------------------===//
/// Returns true if 'val' is a splat of zero, false otherwise.
/// Returns true if 'val' is a splat whose repeated element is a zero that can
/// be materialised with 'arm_sme.zero', false otherwise.
///
/// 'elemType' selects how the splat element is read:
///   - float element types are read as APFloat,
///   - integer element types are read as APInt.
/// Any other element type, a null attribute, or a non-splat attribute yields
/// false.
///
/// For floats only +0.0 qualifies. APFloat::isZero() is also true for -0.0,
/// but -0.0 has its sign bit set. Rewriting such a splat to an all-zero tile
/// would silently turn every element into +0.0, which differs observably
/// (e.g. 1.0 / x, or x + -0.0).
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isPosZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}
namespace {

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

/// Rewrites a splat 'arith.constant' whose type is a valid SME tile vector
/// type into ArmSME operations.
///
///   - A splat recognised by isSplatZero becomes a single 'arm_sme.zero'.
///   - Any other splat is built from an 'arm_sme.get_tile'. A 1-D splat
///     constant (the tile type with its leading dimension dropped) is then
///     written into every tile slice with 'arm_sme.insert_tile_slice', inside
///     the loop produced by arm_sme::createLoopOverTileSlices.
///
/// The match fails (and the op is left alone) when the result type is not a
/// valid SME tile vector type, or the value is not a splat DenseElementsAttr.
struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
                                PatternRewriter &rewriter) const final {
    // Guard: only constants shaped like an SME tile are candidates.
    VectorType tileType = dyn_cast<VectorType>(constantOp.getType());
    const bool isTileConstant =
        tileType && arm_sme::isValidSMETileVectorType(tileType);
    if (!isTileConstant)
      return failure();

    // Guard: the value must be a dense splat so one element describes it all.
    auto splatAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!splatAttr || !splatAttr.isSplat())
      return failure();

    // Fast path: a zero splat maps directly onto 'arm_sme.zero'.
    if (isSplatZero(tileType.getElementType(), splatAttr)) {
      rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
      return success();
    }

    // General path: a tile slice is the widest unit that can be written in
    // one go (setting aside outer-product style ops). So build a 1-D splat
    // matching one slice and insert it into every slice of the tile.
    Location opLoc = constantOp.getLoc();

    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
    Attribute splatElement = splatAttr.getSplatValue<Attribute>();
    auto sliceConstant = rewriter.create<arith::ConstantOp>(
        opLoc, DenseElementsAttr::get(sliceType, splatElement));

    auto initTile = rewriter.create<arm_sme::GetTileOp>(opLoc, tileType);

    // Per-iteration body: write the slice constant into the slice at
    // 'sliceIndex' and hand the updated tile back to the loop.
    auto fillSlice = [&](OpBuilder &builder, Location bodyLoc, Value sliceIndex,
                         Value tile) -> Value {
      return builder
          .create<arm_sme::InsertTileSliceOp>(bodyLoc, tileType, sliceConstant,
                                              tile, sliceIndex)
          .getResult();
    };

    auto sliceLoop = arm_sme::createLoopOverTileSlices(rewriter, opLoc,
                                                       initTile, fillSlice);
    rewriter.replaceOp(constantOp, sliceLoop.getResult(0));
    return success();
  }
};

} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
/// Registers the Arith -> ArmSME lowering patterns (currently the splat
/// 'arith.constant' lowering) in 'patterns', using the set's own context.
void mlir::arith::populateArithToArmSMEConversionPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ConstantOpToArmSMELowering>(ctx);
}
//===----------------------------------------------------------------------===//
// Pass definition
//===----------------------------------------------------------------------===//
namespace {

/// Pass that greedily applies the Arith -> ArmSME lowering patterns to the
/// operation it runs on. It signals pass failure when the greedy driver
/// reports failure.
struct ArithToArmSMEConversionPass final
    : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
  using impl::ArithToArmSMEConversionPassBase<
      ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet loweringPatterns(context);
    arith::populateArithToArmSMEConversionPatterns(loweringPatterns);

    LogicalResult status = applyPatternsAndFoldGreedily(
        getOperation(), std::move(loweringPatterns));
    if (failed(status))
      signalPassFailure();
  }
};

} // namespace