Files
clang-p2996/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Fangrui Song cbb0981388 [mlir] llvm::Optional::value => operator*/operator->
std::optional::value() has undesired exception checking semantics and is
unavailable in older Xcode (see _LIBCPP_AVAILABILITY_BAD_OPTIONAL_ACCESS). The
call sites block std::optional migration.
2022-12-17 19:07:38 +00:00

227 lines
7.4 KiB
C++

//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AMDGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
#include <limits>
using namespace mlir;
using namespace mlir::amdgpu;
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.cpp.inc"
/// Registers the AMDGPU dialect's operations and attributes.
/// The class lists are not written by hand: they are spliced in from the
/// TableGen-generated .cpp.inc files, which emit a comma-separated list of
/// class names when GET_OP_LIST / GET_ATTRDEF_LIST is defined before the
/// include.
void AMDGPUDialect::initialize() {
// Every op listed in the generated AMDGPU.cpp.inc.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc"
>();
// Every attribute listed in the generated AMDGPUAttributes.cpp.inc.
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
/// Shared verifier for the raw buffer load/store/atomic ops.
///
/// Checks, in order, that:
///  - the memref operand is in memory space 0 (the diagnostic calls this
///    global memory);
///  - the memref is ranked;
///  - exactly one index operand is supplied per memref dimension.
/// Emits an op error and returns failure for the first violated condition.
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  auto bufferType = op.getMemref().getType().template cast<MemRefType>();

  const bool inGlobalMemory = bufferType.getMemorySpaceAsInt() == 0;
  if (!inGlobalMemory)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");

  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");

  const int64_t rank = bufferType.getRank();
  const auto numIndices = static_cast<int64_t>(op.getIndices().size());
  if (numIndices == rank)
    return success();
  return op.emitOpError("Expected " + Twine(rank) + " indices to memref");
}
// Each raw buffer op delegates to the shared verifyRawBufferOp helper; the
// template argument is spelled out so the instantiation is explicit.
LogicalResult RawBufferLoadOp::verify() {
  return verifyRawBufferOp<RawBufferLoadOp>(*this);
}

LogicalResult RawBufferStoreOp::verify() {
  return verifyRawBufferOp<RawBufferStoreOp>(*this);
}

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp<RawBufferAtomicFaddOp>(*this);
}
/// Returns the value of `v` as a uint32_t when `v` has type i32 and matches
/// an integer constant; returns std::nullopt for any other type or for a
/// non-constant value.
static Optional<uint32_t> getConstantUint32(Value v) {
  if (!v.getType().isInteger(32))
    return std::nullopt;

  APInt constantBits;
  if (!matchPattern(v, m_ConstantInt(&constantBits)))
    return std::nullopt;

  // Zero-extend, so i32 constants with the top bit set are read as large
  // unsigned values rather than negative numbers.
  return static_cast<uint32_t>(constantBits.getZExtValue());
}
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
if (!op.getBoundsCheck())
return false;
MemRefType bufferType = op.getMemref().getType();
if (!bufferType.hasStaticShape())
return false;
int64_t offset;
SmallVector<int64_t> strides;
if (failed(getStridesAndOffset(bufferType, strides, offset)))
return false;
int64_t result = offset + op.getIndexOffset().value_or(0);
if (op.getSgprOffset()) {
Optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
if (!sgprOffset)
return false;
result += *sgprOffset;
}
if (strides.size() != op.getIndices().size())
return false;
int64_t indexVal = 0;
for (auto pair : llvm::zip(strides, op.getIndices())) {
int64_t stride = std::get<0>(pair);
Value idx = std::get<1>(pair);
Optional<uint32_t> idxVal = getConstantUint32(idx);
if (!idxVal)
return false;
indexVal += stride * *idxVal;
}
result += indexVal;
if (result > std::numeric_limits<uint32_t>::max())
// Overflow means don't drop
return false;
return result >= bufferType.getNumElements();
}
namespace {
/// Replaces a raw buffer load with a zero constant of the load's result type
/// when staticallyOutOfBounds() proves the access out of bounds.
struct RemoveStaticallyOobBufferLoads final
    : public OpRewritePattern<RawBufferLoadOp> {
  using OpRewritePattern<RawBufferLoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(RawBufferLoadOp op,
                                PatternRewriter &rewriter) const override {
    if (staticallyOutOfBounds(op)) {
      Type resultType = op.getResult().getType();
      auto zero = rewriter.getZeroAttr(resultType);
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, zero);
      return success();
    }
    return failure();
  }
};

/// Erases a raw buffer write-like op (store or atomic) when
/// staticallyOutOfBounds() proves the access out of bounds.
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rewriter) const override {
    if (staticallyOutOfBounds(op)) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};
} // end namespace
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  // Fold loads that are provably out of bounds to a zero constant.
  using LoadPattern = RemoveStaticallyOobBufferLoads;
  results.add<LoadPattern>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // Erase stores that are provably out of bounds.
  using StorePattern = RemoveStaticallyOobBufferWrites<RawBufferStoreOp>;
  results.add<StorePattern>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Erase atomic adds that are provably out of bounds.
  using AtomicFaddPattern =
      RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>;
  results.add<AtomicFaddPattern>(context);
}
//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
/// Verifies an MFMA op:
///  - the source (A) value count equals (m * k * blocks) / 64 and the
///    accumulator (C) value count equals (m * n * blocks) / 64, where scalars
///    count as one value and i32 / i64 source elements are first normalized
///    to 4 / 8 packed i8 values;
///  - f64 destinations use no lane permutation (blgp == none, cbsz == 0);
///  - abid < 2 ** cbsz;
///  - negation flags are only used with f64 destinations.
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = sourceType.dyn_cast<VectorType>()) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = destType.dyn_cast<VectorType>()) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  // Normalize the wider integer types the compiler expects to i8
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");

  // `1u << cbsz` is undefined behavior for cbsz >= 32, and nothing above
  // bounds cbsz. For such cbsz every 32-bit abid is already below 2 ** cbsz,
  // so the check only has to run for smaller shift amounts.
  uint32_t cbsz = getCbsz();
  if (cbsz < 32 && getAbid() >= (1u << cbsz))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
#include "mlir/Dialect/AMDGPU/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc"