Files
clang-p2996/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
River Riddle 3fffffa882 [mlir][Pattern] Add a new FrozenRewritePatternList class
This class represents a rewrite pattern list that has been frozen, and thus immutable. This replaces the uses of OwningRewritePatternList in pattern driver related API, such as dialect conversion. When PDL becomes more prevalent, this API will allow for optimizing a set of patterns once without the need to do this per run of a pass.

Differential Revision: https://reviews.llvm.org/D89104
2020-10-26 18:01:06 -07:00

120 lines
4.6 KiB
C++

//===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate SPIRV operations for Vector
// operations.
//
//===----------------------------------------------------------------------===//
#include "../PassDetail.h"
#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
struct VectorBroadcastConvert final
: public SPIRVOpLowering<vector::BroadcastOp> {
using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
LogicalResult
matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (broadcastOp.source().getType().isa<VectorType>() ||
!spirv::CompositeType::isValid(broadcastOp.getVectorType()))
return failure();
vector::BroadcastOp::Adaptor adaptor(operands);
SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
adaptor.source());
Value construct = rewriter.create<spirv::CompositeConstructOp>(
broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
rewriter.replaceOp(broadcastOp, construct);
return success();
}
};
struct VectorExtractOpConvert final
: public SPIRVOpLowering<vector::ExtractOp> {
using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (extractOp.getType().isa<VectorType>() ||
!spirv::CompositeType::isValid(extractOp.getVectorType()))
return failure();
vector::ExtractOp::Adaptor adaptor(operands);
int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
extractOp.getLoc(), adaptor.vector(), id);
rewriter.replaceOp(extractOp, newExtract);
return success();
}
};
struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (insertOp.getSourceType().isa<VectorType>() ||
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
return failure();
vector::InsertOp::Adaptor adaptor(operands);
int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
rewriter.replaceOp(insertOp, newInsert);
return success();
}
};
} // namespace
void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
VectorInsertOpConvert>(context, typeConverter);
}
namespace {
struct LowerVectorToSPIRVPass
: public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
void runOnOperation() override;
};
} // namespace
void LowerVectorToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
OwningRewritePatternList patterns;
populateVectorToSPIRVPatterns(context, typeConverter, patterns);
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
target->addLegalOp<FuncOp>();
if (failed(applyFullConversion(module, *target, std::move(patterns))))
return signalPassFailure();
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToSPIRVPass() {
return std::make_unique<LowerVectorToSPIRVPass>();
}