clang-p2996/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Tres Popp 5550c82189 [mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
mechanism, in addition to defining methods with the same names.
This change begins migrating uses of the methods to the corresponding
free-function calls, which have been decided on as the more consistent
style.
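
For illustration, the mechanical rewrite looks roughly like this
(hypothetical snippet, assuming the usual using-directive for the mlir
namespace):

```
// Before: member-function casts on Type.
static bool isStaticMemRef(Type type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  return memrefType && memrefType.hasStaticShape();
}

// After: free-function casts, as produced by the clang-tidy check below.
static bool isStaticMemRef(Type type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  return memrefType && memrefType.hasStaticShape();
}
```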

Note that some classes, such as AffineExpr, still only define the casting
methods directly; this change does not currently include work to support
functional cast/isa calls for them.

Caveats include:
- This clang-tidy script probably has more problems.
- This only touches C++ code, so generated code is not affected.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
  for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This first patch was created with the following steps. The intention is
to only do automated changes at first, so I waste less time if it's
reverted, and so the first mass change is clearer as an example to
other teams that will need to follow similar steps.

Steps are described per line, as comments are removed by git:
0. Retrieve the change from the following to build clang-tidy with an
   additional check:
   https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over your entire codebase while disabling all checks
   and enabling only the relevant one. Run it on all header files as well.
3. Delete .inc files that were also modified, so the next build rebuilds
   them to a pure state.
4. Some changes have been deleted for the following reasons:
   - Some files had a variable also named cast
   - Some files had not included a header file that defines the cast
     functions
   - Some files are the definitions of the classes that provide the casting
     methods, so unqualified calls there still refer to the method instead
     of the free function unless a namespace prefix is added or the method
     declaration is removed at the same time.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc

git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp\
            mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp\
            mlir/lib/**/IR/\
            mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp\
            mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp\
            mlir/test/lib/Dialect/Test/TestTypes.cpp\
            mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp\
            mlir/test/lib/Dialect/Test/TestAttributes.cpp\
            mlir/unittests/TableGen/EnumsGenTest.cpp\
            mlir/test/python/lib/PythonTestCAPI.cpp\
            mlir/include/mlir/IR/
```

Differential Revision: https://reviews.llvm.org/D150123
2023-05-12 11:21:25 +02:00

//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "memref-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`
/// bits. It's assumed to be non-negative.
///
/// When the array is accessed as if it had elements of `targetBits` bits,
/// multiple source values are loaded at the same time. This method returns the
/// bit offset at which the element at `srcIdx` is located within the loaded
/// value. For example, if `sourceBits` is 8 and `targetBits` is 32, the x-th
/// element is located at (x % 4) * 8, because one i32 holds four 8-bit
/// elements.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
int targetBits, OpBuilder &builder) {
assert(targetBits % sourceBits == 0);
IntegerType targetType = builder.getIntegerType(targetBits);
IntegerAttr idxAttr =
builder.getIntegerAttr(targetType, targetBits / sourceBits);
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
auto srcBitsValue =
builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
}
/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extensions/capabilities, certain integer bitwidths `sourceBits` might not
/// be supported. During conversion, if a memref of an unsupported type is
/// used, loads/stores of this memref need to be modified to use a supported
/// higher bitwidth `targetBits` and to extract the required bits. For
/// accessing a 1-D array (spirv.array or spirv.rt_array), the last index is
/// modified to load the bits needed. The extraction of the actual bits needed
/// is handled separately. Note that this only works for a 1-D tensor.
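/// For example, with `sourceBits` = 8 and `targetBits` = 32, an access to
/// element 5 of the i8 array becomes an access to element 5 / 4 = 1 of the
/// emulating i32 array; the remaining bit offset (5 % 4) * 8 is computed by
/// getOffsetForBitwidth() and applied by the callers.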
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
spirv::AccessChainOp op,
int sourceBits, int targetBits,
OpBuilder &builder) {
assert(targetBits % sourceBits == 0);
const auto loc = op.getLoc();
IntegerType targetType = builder.getIntegerType(targetBits);
IntegerAttr attr =
builder.getIntegerAttr(targetType, targetBits / sourceBits);
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
auto lastDim = op->getOperand(op.getNumOperands() - 1);
auto indices = llvm::to_vector<4>(op.getIndices());
// There are two elements if this is a 1-D tensor.
assert(indices.size() == 2);
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.getComponentPtr().getType());
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
/// Returns the shifted `targetBits`-bit value with the given offset.
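/// `value` is first ANDed with `mask` to keep only the bits being stored, and
/// the result is then shifted left by `offset` bits within the
/// `targetBits`-wide word.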
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
int targetBits, OpBuilder &builder) {
Type targetType = builder.getIntegerType(targetBits);
Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
offset);
}
/// Returns true if the allocation of a memref of `type` by `allocOp` can be
/// lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
return false;
} else if (isa<memref::AllocaOp>(allocOp)) {
auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!sc || sc.getValue() != spirv::StorageClass::Function)
return false;
} else {
return false;
}
// Currently only support static shape and int or float or vector of int or
// float element type.
if (!type.hasStaticShape())
return false;
Type elementType = type.getElementType();
if (auto vecType = dyn_cast<VectorType>(elementType))
elementType = vecType.getElementType();
return elementType.isIntOrFloat();
}
/// Returns the scope to use for atomic operations used to emulate store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
switch (sc.getValue()) {
case spirv::StorageClass::StorageBuffer:
return spirv::Scope::Device;
case spirv::StorageClass::Workgroup:
return spirv::Scope::Workgroup;
default:
break;
}
return {};
}
/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
if (srcInt.getType().isInteger(1))
return srcInt;
auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
return builder.create<spirv::IEqualOp>(loc, srcInt, one);
}
/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
OpBuilder &builder) {
assert(srcBool.getType().isInteger(1));
if (dstType.isInteger(1))
return srcBool;
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
}
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePatterns.
namespace {
/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
: public OpConversionPattern<memref::AtomicRMWOp> {
public:
using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Removes a deallocation if it corresponds to a supported allocation.
/// Currently only removes the deallocation if the memory space is workgroup
/// memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
: public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//
LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType allocType = allocaOp.getType();
if (!isAllocationSupported(allocaOp, allocType))
return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
// Get the SPIR-V type for the allocation.
Type spirvType = getTypeConverter()->convertType(allocType);
rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
spirv::StorageClass::Function,
/*initializer=*/nullptr);
return success();
}
//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//
LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType allocType = operation.getType();
if (!isAllocationSupported(operation, allocType))
return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
// Get the SPIR-V type for the allocation.
Type spirvType = getTypeConverter()->convertType(allocType);
// Insert spirv.GlobalVariable for this allocation.
Operation *parent =
SymbolTable::getNearestSymbolTable(operation->getParentOp());
if (!parent)
return failure();
Location loc = operation.getLoc();
spirv::GlobalVariableOp varOp;
{
OpBuilder::InsertionGuard guard(rewriter);
Block &entryBlock = *parent->getRegion(0).begin();
rewriter.setInsertionPointToStart(&entryBlock);
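// Derive a unique name from the number of spirv.GlobalVariable ops already
// present in the block, e.g. __workgroup_mem__0 for the first allocation.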
auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
std::string varName =
std::string("__workgroup_mem__") +
std::to_string(std::distance(varOps.begin(), varOps.end()));
varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
/*initializer=*/nullptr);
}
// Get pointer to global variable at the current scope.
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
return success();
}
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (isa<FloatType>(atomicOp.getType()))
return rewriter.notifyMatchFailure(atomicOp,
"unimplemented floating-point case");
auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
if (!scope)
return rewriter.notifyMatchFailure(atomicOp,
"unsupported memref memory space");
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Type resultType = typeConverter.convertType(atomicOp.getType());
if (!resultType)
return rewriter.notifyMatchFailure(atomicOp,
"failed to convert result type");
auto loc = atomicOp.getLoc();
Value ptr =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
adaptor.getIndices(), loc, rewriter);
if (!ptr)
return failure();
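// Expands to a case that replaces the memref.atomic_rmw with the matching
// SPIR-V atomic op, using acquire-release semantics at the memref's scope.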
#define ATOMIC_CASE(kind, spirvOp) \
case arith::AtomicRMWKind::kind: \
rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
atomicOp, resultType, ptr, *scope, \
spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
break
switch (atomicOp.getKind()) {
ATOMIC_CASE(addi, AtomicIAddOp);
ATOMIC_CASE(maxs, AtomicSMaxOp);
ATOMIC_CASE(maxu, AtomicUMaxOp);
ATOMIC_CASE(mins, AtomicSMinOp);
ATOMIC_CASE(minu, AtomicUMinOp);
ATOMIC_CASE(ori, AtomicOrOp);
ATOMIC_CASE(andi, AtomicAndOp);
default:
return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
}
#undef ATOMIC_CASE
return success();
}
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
if (!isAllocationSupported(operation, deallocType))
return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
rewriter.eraseOp(operation);
return success();
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = loadOp.getLoc();
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (!memrefType.getElementType().isSignlessInteger())
return failure();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
adaptor.getIndices(), loc, rewriter);
if (!accessChain)
return failure();
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
if (isBool)
srcBits = typeConverter.getOptions().boolNumBits;
auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
if (!pointerType)
return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
Type pointeeType = pointerType.getPointeeType();
Type dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
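// With the Kernel capability, the converted memref is a plain pointer, so the
// pointee is either an array of the element type or the element type itself.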
if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
dstType = arrayType.getElementType();
else
dstType = pointeeType;
} else {
// For Vulkan we need to extract the element from the wrapping struct/array.
Type structElemType =
cast<spirv::StructType>(pointeeType).getElementType(0);
if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
dstType = arrayType.getElementType();
else
dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
}
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
// If the rewritten load op has the same bit width, use the loaded value
// directly.
if (srcBits == dstBits) {
Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
return success();
}
// Bitcasting is currently unsupported for Kernel capability /
// spirv.PtrAccessChain.
if (typeConverter.allows(spirv::Capability::Kernel))
return failure();
auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
if (!accessChainOp)
return failure();
// Assume that getElementPtr() produces a linearized access. If the source is
// a scalar, the method still returns a linearized access chain. If the access
// is not linearized, there will be offset issues.
assert(accessChainOp.getIndices().size() == 2);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
Value spvLoadOp = rewriter.create<spirv::LoadOp>(
loc, dstType, adjustedPtr,
loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>()),
loadOp->getAttrOfType<IntegerAttr>("alignment"));
// Shift the bits to the rightmost position.
// ____XXXX________ -> ____________XXXX
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
loc, spvLoadOp.getType(), spvLoadOp, offset);
// Apply the mask to extract corresponding bits.
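// E.g., for srcBits = 8 the mask is (1 << 8) - 1 = 0xFF.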
Value mask = rewriter.create<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
// Apply sign extension on the loaded value unconditionally. The signedness
// semantics are carried in the operation itself; we rely on other patterns to
// handle the casting.
IntegerAttr shiftValueAttr =
rewriter.getIntegerAttr(dstType, dstBits - srcBits);
Value shiftValue =
rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
shiftValue);
result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
shiftValue);
if (isBool) {
dstType = typeConverter.convertType(loadOp.getType());
mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
} else if (result.getType().getIntOrFloatBitWidth() !=
static_cast<unsigned>(dstBits)) {
result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
}
rewriter.replaceOp(loadOp, result);
assert(accessChainOp.use_empty());
rewriter.eraseOp(accessChainOp);
return success();
}
LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();
auto loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
adaptor.getIndices(), loadOp.getLoc(), rewriter);
if (!loadPtr)
return failure();
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
return success();
}
LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
if (!memrefType.getElementType().isSignlessInteger())
return failure();
auto loc = storeOp.getLoc();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
adaptor.getIndices(), loc, rewriter);
if (!accessChain)
return failure();
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
if (isBool)
srcBits = typeConverter.getOptions().boolNumBits;
auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
if (!pointerType)
return rewriter.notifyMatchFailure(storeOp,
"failed to convert memref type");
Type pointeeType = pointerType.getPointeeType();
Type dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
dstType = arrayType.getElementType();
else
dstType = pointeeType;
} else {
// For Vulkan we need to extract the element from the wrapping struct/array.
Type structElemType =
cast<spirv::StructType>(pointeeType).getElementType(0);
if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
dstType = arrayType.getElementType();
else
dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
}
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
if (srcBits == dstBits) {
Value storeVal = adaptor.getValue();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
return success();
}
// Bitcasting is currently unsupported for Kernel capability /
// spirv.PtrAccessChain.
if (typeConverter.allows(spirv::Capability::Kernel))
return failure();
auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
if (!accessChainOp)
return failure();
// Since multiple threads may be writing concurrently, the emulation is done
// with atomic operations. E.g., if the stored value is i8, rewrite the
// StoreOp to
// 1) load a 32-bit integer
// 2) clear 8 bits in the loaded value
// 3) store the 32-bit value back
// 4) load a 32-bit integer
// 5) modify 8 bits in the loaded value
// 6) store the 32-bit value back
// Steps 1 to 3 are done by AtomicAnd as one atomic step, and steps 4 to 6 are
// done by AtomicOr as another atomic step.
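// For example, storing 0xAB as the second i8 within an i32 word: AtomicAnd
// with mask 0xFFFF00FF clears the target byte, and AtomicOr with 0x0000AB00
// writes the shifted value.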
assert(accessChainOp.getIndices().size() == 2);
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
// Create a mask to clear the destination. E.g., if it is the second i8 in
// i32, 0xFFFF00FF is created.
Value mask = rewriter.create<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
Value clearBitsMask =
rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
Value storeVal = adaptor.getValue();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
if (!scope)
return failure();
Value result = rewriter.create<spirv::AtomicAndOp>(
loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
clearBitsMask);
result = rewriter.create<spirv::AtomicOrOp>(
loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
storeVal);
// The AtomicOrOp has no side effect. Since it is already inserted, we can
// just remove the original StoreOp. Note that rewriter.replaceOp() doesn't
// work here because it requires the number of results of the replacement to
// match that of the original op.
rewriter.eraseOp(storeOp);
assert(accessChainOp.use_empty());
rewriter.eraseOp(accessChainOp);
return success();
}
//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//
LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = addrCastOp.getLoc();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
if (!typeConverter.allows(spirv::Capability::Kernel))
return rewriter.notifyMatchFailure(
loc, "address space casts require kernel capability");
auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
if (!sourceType)
return rewriter.notifyMatchFailure(
loc, "SPIR-V lowering requires ranked memref types");
auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
auto sourceStorageClassAttr =
dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
if (!sourceStorageClassAttr)
return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
diag << "source address space " << sourceType.getMemorySpace()
<< " must be a SPIR-V storage class";
});
auto resultStorageClassAttr =
dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
if (!resultStorageClassAttr)
return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
diag << "result address space " << resultType.getMemorySpace()
<< " must be a SPIR-V storage class";
});
spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
Value result = adaptor.getSource();
Type resultPtrType = typeConverter.convertType(resultType);
Type genericPtrType = resultPtrType;
// SPIR-V doesn't have a general address space cast operation. Instead, it has
// conversions to and from generic pointers. To implement the general case,
// we use specific-to-generic conversions when the source class is not
// generic. Then, when the result storage class is not generic, we convert the
// generic pointer (either the input or an intermediate result) to that
// class. This also means that we'll need the intermediate generic pointer
// type if neither the source nor the destination has it.
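// For example, a cast from Workgroup to CrossWorkgroup lowers to a
// spirv.PtrCastToGeneric followed by a spirv.GenericCastToPtr, going through
// an intermediate Generic-pointer type.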
if (sourceSc != spirv::StorageClass::Generic &&
resultSc != spirv::StorageClass::Generic) {
Type intermediateType =
MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
sourceType.getLayout(),
rewriter.getAttr<spirv::StorageClassAttr>(
spirv::StorageClass::Generic));
genericPtrType = typeConverter.convertType(intermediateType);
}
if (sourceSc != spirv::StorageClass::Generic) {
result =
rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
}
if (resultSc != spirv::StorageClass::Generic) {
result =
rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
}
rewriter.replaceOp(addrCastOp, result);
return success();
}
LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();
auto storePtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
adaptor.getIndices(), storeOp.getLoc(), rewriter);
if (!storePtr)
return failure();
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
adaptor.getValue());
return success();
}
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern>(
typeConverter, patterns.getContext());
}
} // namespace mlir