Files
clang-p2996/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
River Riddle 03d136cf5f [mlir] Promote the SubElementInterfaces to a core Attribute/Type construct
This commit restructures the sub element infrastructure to be a core part
of attributes and types, instead of being relegated to an interface. This
establishes sub element walking/replacement as something "always there",
which makes it easier to rely on for correctness/etc (which various bits of
infrastructure want, such as Symbols).

Attribute/Type now have `walk` and `replace` methods directly
accessible, which provide a powerful API for interacting with sub elements. As
part of this, a new AttrTypeWalker class is introduced that supports caching
walked attributes/types, and a friendlier API (see the simplification of symbol
walking in SymbolTable.cpp).

Differential Revision: https://reviews.llvm.org/D142272
2023-01-27 15:28:03 -08:00

180 lines
6.5 KiB
C++

//===- LowerMemorySpaceAttributes.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// Implementation of a pass that rewrites the IR so that uses of
/// `gpu::AddressSpaceAttr` in memref memory space annotations are replaced
/// with caller-specified numeric values.
///
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
namespace mlir {
#define GEN_PASS_DEF_GPULOWERMEMORYSPACEATTRIBUTESPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::gpu;
//===----------------------------------------------------------------------===//
// Conversion Target
//===----------------------------------------------------------------------===//
/// Legality predicate for types during memory space attribute lowering.
/// A type is illegal only when it is a memref (`BaseMemRefType`) whose memory
/// space is a `gpu::AddressSpaceAttr`; every other type is legal.
static bool isLegalType(Type type) {
  auto memRefType = type.dyn_cast<BaseMemRefType>();
  if (!memRefType)
    return true;
  // The memory space may be absent (null); only a GPU address space
  // attribute makes the memref illegal.
  Attribute memorySpace = memRefType.getMemorySpace();
  return !memorySpace.isa_and_nonnull<gpu::AddressSpaceAttr>();
}
/// Legality predicate for attributes during memory space attribute lowering.
/// Only the type wrapped by a `TypeAttr` is inspected (via isLegalType); all
/// other attributes are treated as legal.
static bool isLegalAttr(Attribute attr) {
  auto typeAttr = attr.dyn_cast<TypeAttr>();
  return !typeAttr || isLegalType(typeAttr.getValue());
}
/// Legality predicate for operations during memory space attribute lowering.
///
/// For ops implementing FunctionOpInterface, legality is decided by the
/// function's argument types, result types, and the argument types of its
/// body region. Any other op is legal when all of its operand types, result
/// types, and attributes are legal.
static bool isLegalOp(Operation *op) {
  if (auto func = dyn_cast<FunctionOpInterface>(op))
    return llvm::all_of(func.getArgumentTypes(), isLegalType) &&
           llvm::all_of(func.getResultTypes(), isLegalType) &&
           llvm::all_of(func.getFunctionBody().getArgumentTypes(),
                        isLegalType);

  if (!llvm::all_of(op->getOperandTypes(), isLegalType))
    return false;
  if (!llvm::all_of(op->getResultTypes(), isLegalType))
    return false;
  return llvm::all_of(op->getAttrs(), [](const NamedAttribute &named) {
    return isLegalAttr(named.getValue());
  });
}
/// Registers the memory-space legality rule on `target`: ops without a more
/// specific legality action are dynamically legal exactly when isLegalOp
/// accepts them.
void gpu::populateLowerMemorySpaceOpLegality(ConversionTarget &target) {
  target.markUnknownOpDynamicallyLegal(
      [](Operation *op) { return isLegalOp(op); });
}
//===----------------------------------------------------------------------===//
// Type Converter
//===----------------------------------------------------------------------===//
/// Builds the numeric memory space attribute that replaces a
/// `gpu::AddressSpaceAttr`: `space` stored in a 64-bit integer attribute.
IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
  IntegerType i64Type = IntegerType::get(ctx, 64);
  return IntegerAttr::get(i64Type, space);
}
/// Adds a type conversion to `typeConverter` that rewrites every
/// `gpu::AddressSpaceAttr` reachable through `Type::replace` into the
/// 64-bit integer attribute produced from `mapping(space)`.
void mlir::gpu::populateMemorySpaceAttributeTypeConversions(
    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
  typeConverter.addConversion([mapping](Type type) {
    return type.replace(
        [mapping](Attribute attr) -> std::optional<Attribute> {
          auto addrSpace = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
          // Attributes that are not GPU address spaces get no replacement
          // from this callback.
          if (!addrSpace)
            return std::nullopt;
          return wrapNumericMemorySpace(attr.getContext(),
                                        mapping(addrSpace.getValue()));
        });
  });
}
namespace {
/// Catch-all conversion pattern that recreates an op with its types run
/// through the pattern's type converter. It matches any op kind. The new op:
///   - takes its operands from the already-converted `operands`;
///   - gets converted result types;
///   - gets `TypeAttr` attributes with their wrapped type converted, while
///     all other attributes are copied as-is;
///   - receives the old op's regions, moved in and given converted entry
///     signatures.
/// In this file it replaces `gpu::AddressSpaceAttr` memref memory spaces
/// with numeric ones.
///
/// All type conversions are attempted before the IR is modified. If any of
/// them fails, the pattern reports a match failure instead of building an op
/// with null or incomplete types.
struct LowerMemRefAddressSpacePattern final : public ConversionPattern {
  LowerMemRefAddressSpacePattern(MLIRContext *context, TypeConverter &converter)
      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    TypeConverter *converter = getTypeConverter();

    // Convert the type wrapped by each TypeAttr. A null result means the
    // conversion failed; TypeAttr::get must not be handed a null type.
    SmallVector<NamedAttribute> newAttrs;
    newAttrs.reserve(op->getAttrs().size());
    for (NamedAttribute attr : op->getAttrs()) {
      auto typeAttr = attr.getValue().dyn_cast<TypeAttr>();
      if (!typeAttr) {
        newAttrs.push_back(attr);
        continue;
      }
      Type newType = converter->convertType(typeAttr.getValue());
      if (!newType)
        return rewriter.notifyMatchFailure(op,
                                           "failed to convert TypeAttr value");
      newAttrs.emplace_back(attr.getName(), TypeAttr::get(newType));
    }

    // Bail out rather than build an op whose result types are incomplete.
    SmallVector<Type> newResults;
    if (failed(converter->convertTypes(op->getResultTypes(), newResults)))
      return rewriter.notifyMatchFailure(op, "failed to convert result types");

    // Compute region signature conversions from the original regions before
    // anything is moved, so a failure here leaves the IR untouched.
    SmallVector<TypeConverter::SignatureConversion, 1> regionConversions;
    regionConversions.reserve(op->getNumRegions());
    for (Region &region : op->getRegions()) {
      regionConversions.emplace_back(region.getNumArguments());
      if (failed(converter->convertSignatureArgs(region.getArgumentTypes(),
                                                 regionConversions.back())))
        return rewriter.notifyMatchFailure(
            op, "failed to convert region argument types");
    }

    // All conversions succeeded: build the replacement op and move regions.
    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
                         newResults, newAttrs, op->getSuccessors());
    for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
      Region *newRegion = state.addRegion();
      rewriter.inlineRegionBefore(op->getRegion(i), *newRegion,
                                  newRegion->begin());
      rewriter.applySignatureConversion(newRegion, regionConversions[i]);
    }
    Operation *newOp = rewriter.create(state);
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
} // namespace
/// Adds the memory-space lowering pattern to `patterns`. It is a single
/// pattern that matches any op kind and uses `typeConverter` for all type
/// conversions.
void mlir::gpu::populateMemorySpaceLoweringPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<LowerMemRefAddressSpacePattern>(ctx, typeConverter);
}
namespace {
/// Pass that lowers `gpu::AddressSpaceAttr` memory spaces to the numeric
/// address spaces given by the pass options `globalAddrSpace`,
/// `workgroupAddrSpace`, and `privateAddrSpace`. It runs a full dialect
/// conversion on the operation the pass is scheduled on.
class LowerMemorySpaceAttributesPass
    : public mlir::impl::GPULowerMemorySpaceAttributesPassBase<
          LowerMemorySpaceAttributesPass> {
public:
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Maps each symbolic GPU address space to its configured numeric value.
    auto toNumericSpace = [this](AddressSpace space) -> unsigned {
      switch (space) {
      case AddressSpace::Global:
        return globalAddrSpace;
      case AddressSpace::Workgroup:
        return workgroupAddrSpace;
      case AddressSpace::Private:
        return privateAddrSpace;
      }
      llvm_unreachable("unknown address space enum value");
      return 0;
    };

    // Register the identity conversion first, then the memory-space
    // conversion on top of it.
    TypeConverter converter;
    converter.addConversion([](Type t) { return t; });
    populateMemorySpaceAttributeTypeConversions(converter, toNumericSpace);

    ConversionTarget legality(*ctx);
    populateLowerMemorySpaceOpLegality(legality);

    RewritePatternSet rewrites(ctx);
    populateMemorySpaceLoweringPatterns(converter, rewrites);

    if (failed(applyFullConversion(getOperation(), legality,
                                   std::move(rewrites))))
      return signalPassFailure();
  }
};
} // namespace