Add support for load/store with chunk_size, which requires special consideration for operand blocking, since offsets and masks are n-D while the tensor descriptor is (n+1)-D. Supported operations include create_tdesc, update_tdesc, load, store, and prefetch.

---------

Co-authored-by: Adam Siemieniuk <adam.siemieniuk@intel.com>
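For illustration, a rough sketch of the chunked case (XeGPU syntax abbreviated; the 16x8 descriptor, the 16x2 native shape, and all names are illustrative assumptions, not taken from the actual tests):

    // 16 addresses, each accessing a chunk of 8 contiguous elements:
    %t = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
           -> tensor_desc<16x8xf32, chunk_size = 8>

    // Blocking to a 16x2 native shape splits only the chunk dimension. The
    // 1-D offsets cannot be tiled like the 2-D descriptor; instead each of
    // the 4 sub-descriptors reuses the same offsets shifted by the blocked
    // chunk size (i = 0..3):
    //   %off_i = arith.addi %offsets, splat(i * 2) : vector<16xindex>
    //   %t_i   = xegpu.create_tdesc %src, %off_i
    //              -> tensor_desc<16x2xf32, chunk_size = 2>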
//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// similar concept and design as vector unroll patterns, serving as a complement
// to them.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace {

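/// Base class for the unroll patterns below. It queries the per-op native
/// (target) tile shape through the callbacks in UnrollOptions, and provides
/// pack/unpack helpers that split a value into tiles and stitch the tiled
/// results back together.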
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("");
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--fails filter constraint -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects the nativeShape callback to be provided.");
    return options.nativeShape(op);
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

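/// Unrolls xegpu.create_nd_tdesc into one descriptor per tile, shifting the
/// innermost `rank` offsets by each tile's position within the original
/// shape.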
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D memrefs where n > rank, we need to handle the last `rank`
    // dimensions only, and keep the first `n-rank` dimensions as is.
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

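/// Unrolls xegpu.update_nd_offset by applying the same offsets to every
/// blocked descriptor.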
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

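/// Unrolls xegpu.prefetch_nd into one prefetch per blocked descriptor.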
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

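/// Unrolls xegpu.load_nd by loading each blocked descriptor into a
/// tile-shaped vector and reassembling the tiles into the original result.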
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

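/// Unrolls xegpu.store_nd by splitting the value into tiles and storing each
/// tile through its corresponding blocked descriptor.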
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

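/// Unrolls xegpu.dpas along M, N, and K: A is tiled as MxK, B as KxN, and
/// the optional accumulator C as MxN; each (i, j) result tile is computed by
/// a chain of dpas ops accumulating over the K tiles.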
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2-D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if every operand has an invalid blocking size (empty)
    // or if the original shape matches the blocking size (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

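/// Unrolls xegpu.create_tdesc for scattered descriptors. Since the offsets
/// are one dimension lower than the descriptor when chunkSize > 1, blocking
/// the chunk dimension emits shifted copies of each offset tile instead of
/// tiling the offsets themselves.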
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSize();
    // indiceVec is one dimension lower than tdescTy when chunkSize is larger
    // than 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;

    // More indices are needed when chunkSize > 1, since one large load from a
    // single address may be broken into multiple smaller loads.
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the offset.
          Value inc = rewriter.create<arith::ConstantIndexOp>(
              loc, i * blockedChunkSize);
          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
          Value offsetIndice =
              rewriter.create<arith::AddIOp>(loc, indice, incVec);

          auto newOp = rewriter.create<xegpu::CreateDescOp>(
              loc, newTdescTy, op.getSource(), offsetIndice);

          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = rewriter.create<xegpu::CreateDescOp>(
            loc, newTdescTy, op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

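/// Unrolls xegpu.load (gather). When chunkSize > 1 the mask is one dimension
/// lower than the loaded value, so each blocked mask tile is replicated
/// across the chunk splits, and the result tile shape is swapped to account
/// for the load's transpose effect.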
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSize();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      SmallVector<Value> convertedMasks1D = pack(
          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto mask : convertedMasks1D) {
        for (int64_t i = 0; i < numNewChunks; ++i)
          convertedMasks.push_back(mask);
      }
      // This is to handle the transpose effect when chunkSize > 1.
      std::swap((*targetShape)[0], (*targetShape)[1]);
      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

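/// Unrolls the scattered xegpu.prefetch into one prefetch per blocked
/// descriptor.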
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

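/// Unrolls xegpu.store (scatter), mirroring the gather pattern above: when
/// chunkSize > 1 the blocked mask tiles are replicated across the chunk
/// splits and the stored tile shape is swapped for the transpose effect.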
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t originalChunkSize = tdescTy.getChunkSize();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
      SmallVector<Value> convertedMasks1D = pack(
          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);

      for (auto mask : convertedMasks1D) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          convertedMasks.push_back(mask);
        }
      }
      // This is to handle the transpose effect when chunkSize > 1.
      std::swap((*targetShape)[0], (*targetShape)[1]);

    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
      convertedMasks =
          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      rewriter.create<xegpu::StoreScatterOp>(
          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

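/// Unrolls xegpu.update_offset. As with create_tdesc, when chunkSize > 1 the
/// offsets are one dimension lower than the descriptor, so each blocked 1-D
/// offset tile is reused for every chunk split.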
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (tdescTy.getRank() > 2)
      return failure();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    SmallVector<Value> newOps;
    int64_t originalChunkSize = tdescTy.getChunkSize();
    if (originalChunkSize > 1) {
      SmallVector<int64_t> shape1D(targetShape->begin(),
                                   targetShape->end() - 1);
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
      SmallVector<Value> convertedOffsetVec1D =
          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto offset : convertedOffsetVec1D) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          convertedOffsetVec.push_back(offset);
        }
      }

    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace

void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
                                                       options);
}