This PR adds support for shape casting from and to 0-D vectors. Reviewed By: nicolasvasilache, hanchung, awarzynski Differential Revision: https://reviews.llvm.org/D151851
196 lines
7.8 KiB
C++
196 lines
7.8 KiB
C++
//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements target-independent rewrites and utilities to lower the
|
|
// 'vector.shape_cast' operation.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
|
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
|
#include "mlir/IR/BuiltinAttributeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/IR/Location.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Interfaces/VectorInterfaces.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
|
|
#define DEBUG_TYPE "vector-shape-cast-lowering"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::vector;
|
|
|
|
namespace {
|
|
/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
|
|
/// vectors progressively on the way to target llvm.matrix intrinsics.
|
|
/// This iterates over the most major dimension of the 2-D vector and performs
|
|
/// rewrites into:
|
|
/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
|
|
class ShapeCastOp2DDownCastRewritePattern
|
|
: public OpRewritePattern<vector::ShapeCastOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto sourceVectorType = op.getSourceVectorType();
|
|
auto resultVectorType = op.getResultVectorType();
|
|
if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
|
|
return failure();
|
|
|
|
auto loc = op.getLoc();
|
|
Value desc = rewriter.create<arith::ConstantOp>(
|
|
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
|
|
unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
|
|
for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
|
|
Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
|
|
desc = rewriter.create<vector::InsertStridedSliceOp>(
|
|
loc, vec, desc,
|
|
/*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
|
|
}
|
|
rewriter.replaceOp(op, desc);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
|
|
/// vectors progressively.
|
|
/// This iterates over the most major dimension of the 2-D vector and performs
|
|
/// rewrites into:
|
|
/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
|
|
/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
|
|
class ShapeCastOp2DUpCastRewritePattern
|
|
: public OpRewritePattern<vector::ShapeCastOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto sourceVectorType = op.getSourceVectorType();
|
|
auto resultVectorType = op.getResultVectorType();
|
|
if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
|
|
return failure();
|
|
|
|
auto loc = op.getLoc();
|
|
Value desc = rewriter.create<arith::ConstantOp>(
|
|
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
|
|
unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
|
|
for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
|
|
Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
|
|
/*sizes=*/mostMinorVectorSize,
|
|
/*strides=*/1);
|
|
desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
|
|
}
|
|
rewriter.replaceOp(op, desc);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// We typically should not lower general shape cast operations into data
|
|
// movement instructions, since the assumption is that these casts are
|
|
// optimized away during progressive lowering. For completeness, however,
|
|
// we fall back to a reference implementation that moves all elements
|
|
// into the right place if we get here.
|
|
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
auto sourceVectorType = op.getSourceVectorType();
|
|
auto resultVectorType = op.getResultVectorType();
|
|
|
|
// Special case 2D / 1D lowerings with better implementations.
|
|
// TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
|
|
int64_t srcRank = sourceVectorType.getRank();
|
|
int64_t resRank = resultVectorType.getRank();
|
|
if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
|
|
return failure();
|
|
|
|
// Generic ShapeCast lowering path goes all the way down to unrolled scalar
|
|
// extract/insert chains.
|
|
// TODO: consider evolving the semantics to only allow 1D source or dest and
|
|
// drop this potentially very expensive lowering.
|
|
// Compute number of elements involved in the reshape.
|
|
int64_t numElts = 1;
|
|
for (int64_t r = 0; r < srcRank; r++)
|
|
numElts *= sourceVectorType.getDimSize(r);
|
|
// Replace with data movement operations:
|
|
// x[0,0,0] = y[0,0]
|
|
// x[0,0,1] = y[0,1]
|
|
// x[0,1,0] = y[0,2]
|
|
// etc., incrementing the two index vectors "row-major"
|
|
// within the source and result shape.
|
|
SmallVector<int64_t> srcIdx(srcRank);
|
|
SmallVector<int64_t> resIdx(resRank);
|
|
Value result = rewriter.create<arith::ConstantOp>(
|
|
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
|
|
for (int64_t i = 0; i < numElts; i++) {
|
|
if (i != 0) {
|
|
incIdx(srcIdx, sourceVectorType, srcRank - 1);
|
|
incIdx(resIdx, resultVectorType, resRank - 1);
|
|
}
|
|
|
|
Value extract;
|
|
if (srcRank == 0) {
|
|
// 0-D vector special case
|
|
assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
|
|
extract = rewriter.create<vector::ExtractElementOp>(
|
|
loc, op.getSourceVectorType().getElementType(), op.getSource());
|
|
} else {
|
|
extract =
|
|
rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
|
|
}
|
|
|
|
if (resRank == 0) {
|
|
// 0-D vector special case
|
|
assert(resIdx.empty() && "Unexpected indices for 0-D vector");
|
|
result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
|
|
} else {
|
|
result =
|
|
rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
|
|
}
|
|
}
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
|
|
assert(0 <= r && r < tp.getRank());
|
|
if (++idx[r] == tp.getDimSize(r)) {
|
|
idx[r] = 0;
|
|
incIdx(idx, tp, r - 1);
|
|
}
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Adds the `vector.shape_cast` lowering patterns defined in this file to
/// `patterns`, each constructed with the set's context and `benefit`: the
/// 2-D -> 1-D pattern, the 1-D -> 2-D pattern and the generic element-wise
/// fallback, in that order.
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ShapeCastOp2DDownCastRewritePattern>(context, benefit);
  patterns.add<ShapeCastOp2DUpCastRewritePattern>(context, benefit);
  patterns.add<ShapeCastOpRewritePattern>(context, benefit);
}
|