|
|
|
|
@@ -1396,10 +1396,8 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
|
|
|
|
: FuncOpConversionBase(converter) {}
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto funcOp = cast<FuncOp>(op);
|
|
|
|
|
|
|
|
|
|
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
|
|
|
|
if (!newFuncOp)
|
|
|
|
|
return failure();
|
|
|
|
|
@@ -1407,14 +1405,14 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
|
|
|
|
if (typeConverter.getOptions().emitCWrappers ||
|
|
|
|
|
funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
|
|
|
|
|
if (newFuncOp.isExternal())
|
|
|
|
|
wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp,
|
|
|
|
|
wrapExternalFunction(rewriter, funcOp.getLoc(), typeConverter, funcOp,
|
|
|
|
|
newFuncOp);
|
|
|
|
|
else
|
|
|
|
|
wrapForExternalCallers(rewriter, op->getLoc(), typeConverter, funcOp,
|
|
|
|
|
wrapForExternalCallers(rewriter, funcOp.getLoc(), typeConverter, funcOp,
|
|
|
|
|
newFuncOp);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
|
rewriter.eraseOp(funcOp);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -1425,10 +1423,8 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
|
|
|
|
using FuncOpConversionBase::FuncOpConversionBase;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto funcOp = cast<FuncOp>(op);
|
|
|
|
|
|
|
|
|
|
// Store the type of memref-typed arguments before the conversion so that we
|
|
|
|
|
// can promote them to MemRef descriptor at the beginning of the function.
|
|
|
|
|
SmallVector<Type, 8> oldArgTypes =
|
|
|
|
|
@@ -1438,7 +1434,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
|
|
|
|
if (!newFuncOp)
|
|
|
|
|
return failure();
|
|
|
|
|
if (newFuncOp.getBody().empty()) {
|
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
|
rewriter.eraseOp(funcOp);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1471,7 +1467,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
|
|
|
|
// TODO: The placeholder is needed to avoid replacing barePtr uses in the
|
|
|
|
|
// MemRef descriptor instructions. We may want to have a utility in the
|
|
|
|
|
// rewriter to properly handle this use case.
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
Location loc = funcOp.getLoc();
|
|
|
|
|
auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
|
|
|
|
|
rewriter.replaceUsesOfBlockArgument(arg, placeholder);
|
|
|
|
|
|
|
|
|
|
@@ -1480,7 +1476,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
|
|
|
|
rewriter.replaceOp(placeholder, {desc});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
|
rewriter.eraseOp(funcOp);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -1711,13 +1707,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
AssertOp::Adaptor transformed(operands);
|
|
|
|
|
|
|
|
|
|
// Insert the `abort` declaration if necessary.
|
|
|
|
|
auto module = op->getParentOfType<ModuleOp>();
|
|
|
|
|
auto module = op.getParentOfType<ModuleOp>();
|
|
|
|
|
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
|
|
|
|
|
if (!abortFunc) {
|
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
|
|
|
@@ -1754,13 +1750,13 @@ struct CreateComplexOpLowering
|
|
|
|
|
using ConvertOpToLLVMPattern<CreateComplexOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(CreateComplexOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto complexOp = cast<CreateComplexOp>(op);
|
|
|
|
|
CreateComplexOp::Adaptor transformed(operands);
|
|
|
|
|
|
|
|
|
|
// Pack real and imaginary part in a complex number struct.
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
auto structType = typeConverter.convertType(complexOp.getType());
|
|
|
|
|
auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
|
|
|
|
|
complexStruct.setReal(rewriter, loc, transformed.real());
|
|
|
|
|
@@ -1775,13 +1771,13 @@ struct ReOpLowering : public ConvertOpToLLVMPattern<ReOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<ReOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(ReOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
ReOp::Adaptor transformed(operands);
|
|
|
|
|
|
|
|
|
|
// Extract real part from the complex number struct.
|
|
|
|
|
ComplexStructBuilder complexStruct(transformed.complex());
|
|
|
|
|
Value real = complexStruct.real(rewriter, op->getLoc());
|
|
|
|
|
Value real = complexStruct.real(rewriter, op.getLoc());
|
|
|
|
|
rewriter.replaceOp(op, real);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
@@ -1792,13 +1788,13 @@ struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<ImOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(ImOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
ImOp::Adaptor transformed(operands);
|
|
|
|
|
|
|
|
|
|
// Extract imaginary part from the complex number struct.
|
|
|
|
|
ComplexStructBuilder complexStruct(transformed.complex());
|
|
|
|
|
Value imaginary = complexStruct.imaginary(rewriter, op->getLoc());
|
|
|
|
|
Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
|
|
|
|
|
rewriter.replaceOp(op, imaginary);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
@@ -1833,9 +1829,8 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<AddCFOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(AddCFOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto op = cast<AddCFOp>(operation);
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
BinaryComplexOperands arg =
|
|
|
|
|
unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter);
|
|
|
|
|
@@ -1861,9 +1856,8 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<SubCFOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(SubCFOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto op = cast<SubCFOp>(operation);
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
BinaryComplexOperands arg =
|
|
|
|
|
unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);
|
|
|
|
|
@@ -1889,9 +1883,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto op = cast<ConstantOp>(operation);
|
|
|
|
|
// If constant refers to a function, convert it to "addressof".
|
|
|
|
|
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
|
|
|
|
|
auto type = typeConverter.convertType(op.getResult().getType())
|
|
|
|
|
@@ -2284,10 +2277,9 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
|
|
|
|
|
using Base = ConvertOpToLLVMPattern<CallOpType>;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(CallOpType callOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
typename CallOpType::Adaptor transformed(operands);
|
|
|
|
|
auto callOp = cast<CallOpType>(op);
|
|
|
|
|
|
|
|
|
|
// Pack the result types into a struct.
|
|
|
|
|
Type packedResult = nullptr;
|
|
|
|
|
@@ -2301,10 +2293,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto promoted = this->typeConverter.promoteOperands(
|
|
|
|
|
op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter);
|
|
|
|
|
callOp.getLoc(), /*opOperands=*/callOp.getOperation()->getOperands(),
|
|
|
|
|
operands, rewriter);
|
|
|
|
|
auto newOp = rewriter.create<LLVM::CallOp>(
|
|
|
|
|
op->getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
|
|
|
|
|
promoted, op->getAttrs());
|
|
|
|
|
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
|
|
|
|
|
promoted, callOp.getAttrs());
|
|
|
|
|
|
|
|
|
|
SmallVector<Value, 4> results;
|
|
|
|
|
if (numResults < 2) {
|
|
|
|
|
@@ -2315,9 +2308,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
|
|
|
|
|
// Extract individual results from the structure and return them as list.
|
|
|
|
|
results.reserve(numResults);
|
|
|
|
|
for (unsigned i = 0; i < numResults; ++i) {
|
|
|
|
|
auto type = this->typeConverter.convertType(op->getResult(i).getType());
|
|
|
|
|
auto type =
|
|
|
|
|
this->typeConverter.convertType(callOp.getResult(i).getType());
|
|
|
|
|
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
|
|
|
|
op->getLoc(), type, newOp.getOperation()->getResult(0),
|
|
|
|
|
callOp.getLoc(), type, newOp.getOperation()->getResult(0),
|
|
|
|
|
rewriter.getI64ArrayAttr(i)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -2327,16 +2321,16 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
|
|
|
|
|
// descriptors.
|
|
|
|
|
assert(results.size() == resultTypes.size() &&
|
|
|
|
|
"The number of arguments and types doesn't match");
|
|
|
|
|
this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(),
|
|
|
|
|
resultTypes, results);
|
|
|
|
|
} else if (failed(copyUnrankedDescriptors(rewriter, op->getLoc(),
|
|
|
|
|
this->typeConverter.promoteBarePtrsToDescriptors(
|
|
|
|
|
rewriter, callOp.getLoc(), resultTypes, results);
|
|
|
|
|
} else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(),
|
|
|
|
|
this->typeConverter, resultTypes,
|
|
|
|
|
results,
|
|
|
|
|
/*toDynamic=*/false))) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, results);
|
|
|
|
|
rewriter.replaceOp(callOp, results);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -2359,18 +2353,18 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
|
|
|
|
|
: ConvertOpToLLVMPattern<DeallocOp>(converter) {}
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(DeallocOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
assert(operands.size() == 1 && "dealloc takes one operand");
|
|
|
|
|
DeallocOp::Adaptor transformed(operands);
|
|
|
|
|
|
|
|
|
|
// Insert the `free` declaration if it is not already present.
|
|
|
|
|
auto freeFunc =
|
|
|
|
|
op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
|
|
|
|
|
op.getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
|
|
|
|
|
if (!freeFunc) {
|
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
|
|
|
rewriter.setInsertionPointToStart(
|
|
|
|
|
op->getParentOfType<ModuleOp>().getBody());
|
|
|
|
|
op.getParentOfType<ModuleOp>().getBody());
|
|
|
|
|
freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
|
|
|
|
|
rewriter.getUnknownLoc(), "free",
|
|
|
|
|
LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
|
|
|
|
|
@@ -2379,8 +2373,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
|
|
|
|
|
|
|
|
|
|
MemRefDescriptor memref(transformed.memref());
|
|
|
|
|
Value casted = rewriter.create<LLVM::BitcastOp>(
|
|
|
|
|
op->getLoc(), getVoidPtrType(),
|
|
|
|
|
memref.allocatedPtr(rewriter, op->getLoc()));
|
|
|
|
|
op.getLoc(), getVoidPtrType(),
|
|
|
|
|
memref.allocatedPtr(rewriter, op.getLoc()));
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
|
|
|
|
op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
|
|
|
|
return success();
|
|
|
|
|
@@ -2410,9 +2404,8 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(GlobalMemrefOp global, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto global = cast<GlobalMemrefOp>(op);
|
|
|
|
|
MemRefType type = global.type().cast<MemRefType>();
|
|
|
|
|
if (!isSupportedMemRefType(type))
|
|
|
|
|
return failure();
|
|
|
|
|
@@ -2434,7 +2427,7 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
|
|
|
|
|
op, arrayTy, global.constant(), linkage, global.sym_name(),
|
|
|
|
|
global, arrayTy, global.constant(), linkage, global.sym_name(),
|
|
|
|
|
initialValue, type.getMemorySpace());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
@@ -2491,7 +2484,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
RsqrtOp::Adaptor transformed(operands);
|
|
|
|
|
auto operandType =
|
|
|
|
|
@@ -2500,8 +2493,8 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
|
|
|
|
|
if (!operandType)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
auto resultType = *op->result_type_begin();
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
auto resultType = op.getResult().getType();
|
|
|
|
|
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
|
|
|
|
|
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
|
|
|
|
|
|
|
|
|
|
@@ -2524,7 +2517,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
return handleMultidimensionalVectors(
|
|
|
|
|
op, operands, typeConverter,
|
|
|
|
|
op.getOperation(), operands, typeConverter,
|
|
|
|
|
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
|
|
|
|
|
auto splatAttr = SplatElementsAttr::get(
|
|
|
|
|
mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
|
|
|
|
|
@@ -2543,8 +2536,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
|
|
|
|
|
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult match(Operation *op) const override {
|
|
|
|
|
auto memRefCastOp = cast<MemRefCastOp>(op);
|
|
|
|
|
LogicalResult match(MemRefCastOp memRefCastOp) const override {
|
|
|
|
|
Type srcType = memRefCastOp.getOperand().getType();
|
|
|
|
|
Type dstType = memRefCastOp.getType();
|
|
|
|
|
|
|
|
|
|
@@ -2568,19 +2560,18 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
|
|
|
|
|
: failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void rewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
void rewrite(MemRefCastOp memRefCastOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto memRefCastOp = cast<MemRefCastOp>(op);
|
|
|
|
|
MemRefCastOp::Adaptor transformed(operands);
|
|
|
|
|
|
|
|
|
|
auto srcType = memRefCastOp.getOperand().getType();
|
|
|
|
|
auto dstType = memRefCastOp.getType();
|
|
|
|
|
auto targetStructType = typeConverter.convertType(memRefCastOp.getType());
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
auto loc = memRefCastOp.getLoc();
|
|
|
|
|
|
|
|
|
|
// For ranked/ranked case, just keep the original descriptor.
|
|
|
|
|
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
|
|
|
|
|
return rewriter.replaceOp(op, {transformed.source()});
|
|
|
|
|
return rewriter.replaceOp(memRefCastOp, {transformed.source()});
|
|
|
|
|
|
|
|
|
|
if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
|
|
|
|
|
// Casting ranked to unranked memref type
|
|
|
|
|
@@ -2607,7 +2598,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
|
|
|
|
|
memRefDesc.setRank(rewriter, loc, rankVal);
|
|
|
|
|
// d2 = InsertValueOp d1, voidptr, 1
|
|
|
|
|
memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
|
|
|
|
|
rewriter.replaceOp(op, (Value)memRefDesc);
|
|
|
|
|
rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
|
|
|
|
|
|
|
|
|
|
} else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
|
|
|
|
|
// Casting from unranked type to ranked.
|
|
|
|
|
@@ -2625,7 +2616,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
|
|
|
|
|
.getResult();
|
|
|
|
|
// struct = LoadOp castPtr
|
|
|
|
|
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
|
|
|
|
|
rewriter.replaceOp(op, loadOp.getResult());
|
|
|
|
|
rewriter.replaceOp(memRefCastOp, loadOp.getResult());
|
|
|
|
|
} else {
|
|
|
|
|
llvm_unreachable("Unsupported unranked memref to unranked memref cast");
|
|
|
|
|
}
|
|
|
|
|
@@ -2680,17 +2671,17 @@ struct MemRefReinterpretCastOpLowering
|
|
|
|
|
using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(MemRefReinterpretCastOp castOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto castOp = cast<MemRefReinterpretCastOp>(op);
|
|
|
|
|
MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary());
|
|
|
|
|
MemRefReinterpretCastOp::Adaptor adaptor(
|
|
|
|
|
operands, castOp.getOperation()->getAttrDictionary());
|
|
|
|
|
Type srcType = castOp.source().getType();
|
|
|
|
|
|
|
|
|
|
Value descriptor;
|
|
|
|
|
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
|
|
|
|
|
adaptor, &descriptor)))
|
|
|
|
|
return failure();
|
|
|
|
|
rewriter.replaceOp(op, {descriptor});
|
|
|
|
|
rewriter.replaceOp(castOp, {descriptor});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -2748,10 +2739,9 @@ struct MemRefReshapeOpLowering
|
|
|
|
|
using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(MemRefReshapeOp reshapeOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto reshapeOp = cast<MemRefReshapeOp>(op);
|
|
|
|
|
|
|
|
|
|
auto *op = reshapeOp.getOperation();
|
|
|
|
|
MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
|
|
|
|
|
Type srcType = reshapeOp.source().getType();
|
|
|
|
|
|
|
|
|
|
@@ -2898,15 +2888,14 @@ struct DialectCastOpLowering
|
|
|
|
|
using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(LLVM::DialectCastOp castOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto castOp = cast<LLVM::DialectCastOp>(op);
|
|
|
|
|
LLVM::DialectCastOp::Adaptor transformed(operands);
|
|
|
|
|
if (transformed.in().getType() !=
|
|
|
|
|
typeConverter.convertType(castOp.getType())) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOp(op, transformed.in());
|
|
|
|
|
rewriter.replaceOp(castOp, transformed.in());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -2917,19 +2906,18 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<DimOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(DimOp dimOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto dimOp = cast<DimOp>(op);
|
|
|
|
|
Type operandType = dimOp.memrefOrTensor().getType();
|
|
|
|
|
if (operandType.isa<UnrankedMemRefType>()) {
|
|
|
|
|
rewriter.replaceOp(op, {extractSizeOfUnrankedMemRef(operandType, dimOp,
|
|
|
|
|
operands, rewriter)});
|
|
|
|
|
rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
|
|
|
|
|
operandType, dimOp, operands, rewriter)});
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
if (operandType.isa<MemRefType>()) {
|
|
|
|
|
rewriter.replaceOp(op, {extractSizeOfRankedMemRef(operandType, dimOp,
|
|
|
|
|
operands, rewriter)});
|
|
|
|
|
rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
|
|
|
|
|
operandType, dimOp, operands, rewriter)});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
return failure();
|
|
|
|
|
@@ -3006,10 +2994,10 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(RankOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
Type operandType = cast<RankOp>(op).memrefOrTensor().getType();
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Type operandType = op.memrefOrTensor().getType();
|
|
|
|
|
if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
|
|
|
|
|
UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
|
|
|
|
|
rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
|
|
|
|
|
@@ -3033,8 +3021,8 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
|
|
|
|
|
using ConvertOpToLLVMPattern<Derived>::isSupportedMemRefType;
|
|
|
|
|
using Base = LoadStoreOpLowering<Derived>;
|
|
|
|
|
|
|
|
|
|
LogicalResult match(Operation *op) const override {
|
|
|
|
|
MemRefType type = cast<Derived>(op).getMemRefType();
|
|
|
|
|
LogicalResult match(Derived op) const override {
|
|
|
|
|
MemRefType type = op.getMemRefType();
|
|
|
|
|
return isSupportedMemRefType(type) ? success() : failure();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -3045,16 +3033,15 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
|
|
|
|
using Base::Base;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto loadOp = cast<LoadOp>(op);
|
|
|
|
|
LoadOp::Adaptor transformed(operands);
|
|
|
|
|
auto type = loadOp.getMemRefType();
|
|
|
|
|
|
|
|
|
|
Value dataPtr =
|
|
|
|
|
getStridedElementPtr(op->getLoc(), type, transformed.memref(),
|
|
|
|
|
getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
|
|
|
|
|
transformed.indices(), rewriter);
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -3065,13 +3052,13 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
|
|
|
|
|
using Base::Base;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(StoreOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto type = cast<StoreOp>(op).getMemRefType();
|
|
|
|
|
auto type = op.getMemRefType();
|
|
|
|
|
StoreOp::Adaptor transformed(operands);
|
|
|
|
|
|
|
|
|
|
Value dataPtr =
|
|
|
|
|
getStridedElementPtr(op->getLoc(), type, transformed.memref(),
|
|
|
|
|
getStridedElementPtr(op.getLoc(), type, transformed.memref(),
|
|
|
|
|
transformed.indices(), rewriter);
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
|
|
|
|
|
dataPtr);
|
|
|
|
|
@@ -3085,29 +3072,26 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
|
|
|
|
|
using Base::Base;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(PrefetchOp prefetchOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto prefetchOp = cast<PrefetchOp>(op);
|
|
|
|
|
PrefetchOp::Adaptor transformed(operands);
|
|
|
|
|
auto type = prefetchOp.getMemRefType();
|
|
|
|
|
auto loc = prefetchOp.getLoc();
|
|
|
|
|
|
|
|
|
|
Value dataPtr =
|
|
|
|
|
getStridedElementPtr(op->getLoc(), type, transformed.memref(),
|
|
|
|
|
transformed.indices(), rewriter);
|
|
|
|
|
Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
|
|
|
|
|
transformed.indices(), rewriter);
|
|
|
|
|
|
|
|
|
|
// Replace with llvm.prefetch.
|
|
|
|
|
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
|
|
|
|
|
auto isWrite = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
op->getLoc(), llvmI32Type,
|
|
|
|
|
rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
|
|
|
|
|
loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
|
|
|
|
|
auto localityHint = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
op->getLoc(), llvmI32Type,
|
|
|
|
|
loc, llvmI32Type,
|
|
|
|
|
rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
|
|
|
|
|
auto isData = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
op->getLoc(), llvmI32Type,
|
|
|
|
|
rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
|
|
|
|
|
loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite,
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
|
|
|
|
|
localityHint, isData);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
@@ -3121,10 +3105,9 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(IndexCastOp indexCastOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
IndexCastOpAdaptor transformed(operands);
|
|
|
|
|
auto indexCastOp = cast<IndexCastOp>(op);
|
|
|
|
|
|
|
|
|
|
auto targetType =
|
|
|
|
|
this->typeConverter.convertType(indexCastOp.getResult().getType())
|
|
|
|
|
@@ -3134,12 +3117,12 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
|
|
|
|
|
unsigned sourceBits = sourceType.getIntegerBitWidth();
|
|
|
|
|
|
|
|
|
|
if (targetBits == sourceBits)
|
|
|
|
|
rewriter.replaceOp(op, transformed.in());
|
|
|
|
|
rewriter.replaceOp(indexCastOp, transformed.in());
|
|
|
|
|
else if (targetBits < sourceBits)
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
|
|
|
|
|
transformed.in());
|
|
|
|
|
else
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
|
|
|
|
|
transformed.in());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
@@ -3156,13 +3139,12 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto cmpiOp = cast<CmpIOp>(op);
|
|
|
|
|
CmpIOpAdaptor transformed(operands);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
|
|
|
|
|
op, typeConverter.convertType(cmpiOp.getResult().getType()),
|
|
|
|
|
cmpiOp, typeConverter.convertType(cmpiOp.getResult().getType()),
|
|
|
|
|
rewriter.getI64IntegerAttr(static_cast<int64_t>(
|
|
|
|
|
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
|
|
|
|
|
transformed.lhs(), transformed.rhs());
|
|
|
|
|
@@ -3175,13 +3157,12 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto cmpfOp = cast<CmpFOp>(op);
|
|
|
|
|
CmpFOpAdaptor transformed(operands);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
|
|
|
|
|
op, typeConverter.convertType(cmpfOp.getResult().getType()),
|
|
|
|
|
cmpfOp, typeConverter.convertType(cmpfOp.getResult().getType()),
|
|
|
|
|
rewriter.getI64IntegerAttr(static_cast<int64_t>(
|
|
|
|
|
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
|
|
|
|
|
transformed.lhs(), transformed.rhs());
|
|
|
|
|
@@ -3243,10 +3224,10 @@ struct OneToOneLLVMTerminatorLowering
|
|
|
|
|
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
|
|
|
|
|
op->getAttrs());
|
|
|
|
|
rewriter.replaceOpWithNewOp<TargetOp>(
|
|
|
|
|
op, operands, op.getOperation()->getSuccessors(), op.getAttrs());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -3261,16 +3242,16 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
unsigned numArguments = op->getNumOperands();
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
unsigned numArguments = op.getNumOperands();
|
|
|
|
|
SmallVector<Value, 4> updatedOperands;
|
|
|
|
|
|
|
|
|
|
if (typeConverter.getOptions().useBarePtrCallConv) {
|
|
|
|
|
// For the bare-ptr calling convention, extract the aligned pointer to
|
|
|
|
|
// be returned from the memref descriptor.
|
|
|
|
|
for (auto it : llvm::zip(op->getOperands(), operands)) {
|
|
|
|
|
for (auto it : llvm::zip(op.getOperation()->getOperands(), operands)) {
|
|
|
|
|
Type oldTy = std::get<0>(it).getType();
|
|
|
|
|
Value newOperand = std::get<1>(it);
|
|
|
|
|
if (oldTy.isa<MemRefType>()) {
|
|
|
|
|
@@ -3286,26 +3267,26 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
|
|
|
|
|
} else {
|
|
|
|
|
updatedOperands = llvm::to_vector<4>(operands);
|
|
|
|
|
copyUnrankedDescriptors(rewriter, loc, typeConverter,
|
|
|
|
|
op->getOperands().getTypes(), updatedOperands,
|
|
|
|
|
op.getOperands().getTypes(), updatedOperands,
|
|
|
|
|
/*toDynamic=*/true);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If ReturnOp has 0 or 1 operand, create it and return immediately.
|
|
|
|
|
if (numArguments == 0) {
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
|
|
|
|
|
op->getAttrs());
|
|
|
|
|
op.getAttrs());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
if (numArguments == 1) {
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
|
|
|
|
|
op, TypeRange(), updatedOperands, op->getAttrs());
|
|
|
|
|
op, TypeRange(), updatedOperands, op.getAttrs());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Otherwise, we need to pack the arguments into an LLVM struct type before
|
|
|
|
|
// returning.
|
|
|
|
|
auto packedType = typeConverter.packFunctionResults(
|
|
|
|
|
llvm::to_vector<4>(op->getOperandTypes()));
|
|
|
|
|
llvm::to_vector<4>(op.getOperandTypes()));
|
|
|
|
|
|
|
|
|
|
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
|
|
|
|
|
for (unsigned i = 0; i < numArguments; ++i) {
|
|
|
|
|
@@ -3314,7 +3295,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
|
|
|
|
|
rewriter.getI64ArrayAttr(i));
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
|
|
|
|
|
op->getAttrs());
|
|
|
|
|
op.getAttrs());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -3335,29 +3316,30 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto splatOp = cast<SplatOp>(op);
|
|
|
|
|
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
|
|
|
|
|
if (!resultType || resultType.getRank() != 1)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
// First insert it into an undef vector so we can shuffle it.
|
|
|
|
|
auto vectorType = typeConverter.convertType(splatOp.getType());
|
|
|
|
|
Value undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
|
|
|
|
|
Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
|
|
|
|
|
auto zero = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
op->getLoc(), typeConverter.convertType(rewriter.getIntegerType(32)),
|
|
|
|
|
splatOp.getLoc(),
|
|
|
|
|
typeConverter.convertType(rewriter.getIntegerType(32)),
|
|
|
|
|
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
|
|
|
|
|
|
|
|
|
|
auto v = rewriter.create<LLVM::InsertElementOp>(
|
|
|
|
|
op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);
|
|
|
|
|
splatOp.getLoc(), vectorType, undef, splatOp.getOperand(), zero);
|
|
|
|
|
|
|
|
|
|
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
|
|
|
|
|
SmallVector<int32_t, 4> zeroValues(width, 0);
|
|
|
|
|
|
|
|
|
|
// Shuffle the value across the desired number of elements.
|
|
|
|
|
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
|
|
|
|
|
zeroAttrs);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -3369,16 +3351,15 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto splatOp = cast<SplatOp>(op);
|
|
|
|
|
SplatOp::Adaptor adaptor(operands);
|
|
|
|
|
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
|
|
|
|
|
if (!resultType || resultType.getRank() == 1)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
// First insert it into an undef vector so we can shuffle it.
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
auto loc = splatOp.getLoc();
|
|
|
|
|
auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter);
|
|
|
|
|
auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
|
|
|
|
|
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
|
|
|
|
|
@@ -3409,7 +3390,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
|
|
|
|
|
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
|
|
|
|
|
position);
|
|
|
|
|
});
|
|
|
|
|
rewriter.replaceOp(op, desc);
|
|
|
|
|
rewriter.replaceOp(splatOp, desc);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -3431,10 +3412,9 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
|
|
|
|
|
using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(SubViewOp subViewOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
auto subViewOp = cast<SubViewOp>(op);
|
|
|
|
|
auto loc = subViewOp.getLoc();
|
|
|
|
|
|
|
|
|
|
auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
|
|
|
|
|
auto sourceElementTy =
|
|
|
|
|
@@ -3545,7 +3525,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
|
|
|
|
|
j--;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, {targetMemRef});
|
|
|
|
|
rewriter.replaceOp(subViewOp, {targetMemRef});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -3562,16 +3542,15 @@ public:
|
|
|
|
|
using ConvertOpToLLVMPattern<TransposeOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(TransposeOp transposeOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
auto loc = transposeOp.getLoc();
|
|
|
|
|
TransposeOpAdaptor adaptor(operands);
|
|
|
|
|
MemRefDescriptor viewMemRef(adaptor.in());
|
|
|
|
|
|
|
|
|
|
auto transposeOp = cast<TransposeOp>(op);
|
|
|
|
|
// No permutation, early exit.
|
|
|
|
|
if (transposeOp.permutation().isIdentity())
|
|
|
|
|
return rewriter.replaceOp(op, {viewMemRef}), success();
|
|
|
|
|
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
|
|
|
|
|
|
|
|
|
|
auto targetMemRef = MemRefDescriptor::undef(
|
|
|
|
|
rewriter, loc, typeConverter.convertType(transposeOp.getShapedType()));
|
|
|
|
|
@@ -3596,7 +3575,7 @@ public:
|
|
|
|
|
viewMemRef.stride(rewriter, loc, sourcePos));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, {targetMemRef});
|
|
|
|
|
rewriter.replaceOp(transposeOp, {targetMemRef});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -3643,10 +3622,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(ViewOp viewOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
auto viewOp = cast<ViewOp>(op);
|
|
|
|
|
auto loc = viewOp.getLoc();
|
|
|
|
|
ViewOpAdaptor adaptor(operands);
|
|
|
|
|
|
|
|
|
|
auto viewMemRefType = viewOp.getType();
|
|
|
|
|
@@ -3656,14 +3634,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
|
|
|
|
|
auto targetDescTy =
|
|
|
|
|
typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
|
|
|
|
|
if (!targetDescTy)
|
|
|
|
|
return op->emitWarning("Target descriptor type not converted to LLVM"),
|
|
|
|
|
return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
|
|
|
|
|
failure();
|
|
|
|
|
|
|
|
|
|
int64_t offset;
|
|
|
|
|
SmallVector<int64_t, 4> strides;
|
|
|
|
|
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
|
|
|
|
|
if (failed(successStrides))
|
|
|
|
|
return op->emitWarning("cannot cast to non-strided shape"), failure();
|
|
|
|
|
return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
|
|
|
|
|
assert(offset == 0 && "expected offset to be 0");
|
|
|
|
|
|
|
|
|
|
// Create the descriptor.
|
|
|
|
|
@@ -3695,11 +3673,12 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
|
|
|
|
|
|
|
|
|
|
// Early exit for 0-D corner case.
|
|
|
|
|
if (viewMemRefType.getRank() == 0)
|
|
|
|
|
return rewriter.replaceOp(op, {targetMemRef}), success();
|
|
|
|
|
return rewriter.replaceOp(viewOp, {targetMemRef}), success();
|
|
|
|
|
|
|
|
|
|
// Fields 4 and 5: Update sizes and strides.
|
|
|
|
|
if (strides.back() != 1)
|
|
|
|
|
return op->emitWarning("cannot cast to non-contiguous shape"), failure();
|
|
|
|
|
return viewOp.emitWarning("cannot cast to non-contiguous shape"),
|
|
|
|
|
failure();
|
|
|
|
|
Value stride = nullptr, nextSize = nullptr;
|
|
|
|
|
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
|
|
|
|
|
// Update size.
|
|
|
|
|
@@ -3712,7 +3691,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
|
|
|
|
|
nextSize = size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, {targetMemRef});
|
|
|
|
|
rewriter.replaceOp(viewOp, {targetMemRef});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@@ -3722,11 +3701,12 @@ struct AssumeAlignmentOpLowering
|
|
|
|
|
using ConvertOpToLLVMPattern<AssumeAlignmentOp>::ConvertOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(AssumeAlignmentOp op, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
AssumeAlignmentOp::Adaptor transformed(operands);
|
|
|
|
|
Value memref = transformed.memref();
|
|
|
|
|
unsigned alignment = cast<AssumeAlignmentOp>(op).alignment();
|
|
|
|
|
unsigned alignment = op.alignment();
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
|
|
|
|
|
MemRefDescriptor memRefDescriptor(memref);
|
|
|
|
|
Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
|
|
|
|
|
@@ -3741,16 +3721,14 @@ struct AssumeAlignmentOpLowering
|
|
|
|
|
// pointer SSA value.
|
|
|
|
|
auto intPtrType =
|
|
|
|
|
getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
|
|
|
|
|
Value zero = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0);
|
|
|
|
|
Value mask = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType,
|
|
|
|
|
alignment - 1);
|
|
|
|
|
Value ptrValue =
|
|
|
|
|
rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), intPtrType, ptr);
|
|
|
|
|
Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
|
|
|
|
|
Value mask =
|
|
|
|
|
createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
|
|
|
|
|
Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
|
|
|
|
|
rewriter.create<LLVM::AssumeOp>(
|
|
|
|
|
op->getLoc(),
|
|
|
|
|
rewriter.create<LLVM::ICmpOp>(
|
|
|
|
|
op->getLoc(), LLVM::ICmpPredicate::eq,
|
|
|
|
|
rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));
|
|
|
|
|
loc, rewriter.create<LLVM::ICmpOp>(
|
|
|
|
|
loc, LLVM::ICmpPredicate::eq,
|
|
|
|
|
rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
|
|
|
|
|
|
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
|
return success();
|
|
|
|
|
@@ -3789,9 +3767,10 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
|
|
|
|
using Base::Base;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto atomicOp = cast<AtomicRMWOp>(op);
|
|
|
|
|
if (failed(match(atomicOp)))
|
|
|
|
|
return failure();
|
|
|
|
|
auto maybeKind = matchSimpleAtomicOp(atomicOp);
|
|
|
|
|
if (!maybeKind)
|
|
|
|
|
return failure();
|
|
|
|
|
@@ -3799,10 +3778,10 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
|
|
|
|
auto resultType = adaptor.value().getType();
|
|
|
|
|
auto memRefType = atomicOp.getMemRefType();
|
|
|
|
|
auto dataPtr =
|
|
|
|
|
getStridedElementPtr(op->getLoc(), memRefType, adaptor.memref(),
|
|
|
|
|
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
|
|
|
|
|
adaptor.indices(), rewriter);
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
|
|
|
|
|
op, resultType, *maybeKind, dataPtr, adaptor.value(),
|
|
|
|
|
atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
|
|
|
|
|
LLVM::AtomicOrdering::acq_rel);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
@@ -3840,11 +3819,10 @@ struct GenericAtomicRMWOpLowering
|
|
|
|
|
using Base::Base;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
|
matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef<Value> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto atomicOp = cast<GenericAtomicRMWOp>(op);
|
|
|
|
|
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
auto loc = atomicOp.getLoc();
|
|
|
|
|
GenericAtomicRMWOp::Adaptor adaptor(operands);
|
|
|
|
|
LLVM::LLVMType valueType =
|
|
|
|
|
typeConverter.convertType(atomicOp.getResult().getType())
|
|
|
|
|
@@ -3908,7 +3886,7 @@ struct GenericAtomicRMWOpLowering
|
|
|
|
|
std::next(opsToMoveEnd), rewriter);
|
|
|
|
|
|
|
|
|
|
// The 'result' of the atomic_rmw op is the newly loaded value.
|
|
|
|
|
rewriter.replaceOp(op, {newLoaded});
|
|
|
|
|
rewriter.replaceOp(atomicOp, {newLoaded});
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|