This commit reorganizes the SPIR-V code to better follow MLIR
conventions. Specifically:
* Created IR/, Transforms/, Linking/, and Utils/ subdirectories and
moved the relevant code into them.
* Created SPIRVEnums.{h|cpp} for the SPIR-V C/C++ enums generated from
the SPIR-V spec. Previously they were cluttered inside SPIRVTypes.{h|cpp}.
* Fixed include guards in various header files (both .h and .td).
* Moved serialization tests under test/Target/SPIRV.
* Renamed the TableGen backend -gen-spirv-op-utils to -gen-spirv-attr-utils,
as it only generates utility functions for attributes.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D93407
158 lines
6.3 KiB
C++
158 lines
6.3 KiB
C++
//===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements conversion patterns and a pass that lower Vector
// dialect operations to SPIR-V operations.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h"
|
|
#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
|
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
struct VectorBroadcastConvert final
|
|
: public SPIRVOpLowering<vector::BroadcastOp> {
|
|
using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
|
|
LogicalResult
|
|
matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (broadcastOp.source().getType().isa<VectorType>() ||
|
|
!spirv::CompositeType::isValid(broadcastOp.getVectorType()))
|
|
return failure();
|
|
vector::BroadcastOp::Adaptor adaptor(operands);
|
|
SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
|
|
adaptor.source());
|
|
Value construct = rewriter.create<spirv::CompositeConstructOp>(
|
|
broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
|
|
rewriter.replaceOp(broadcastOp, construct);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorExtractOpConvert final
|
|
: public SPIRVOpLowering<vector::ExtractOp> {
|
|
using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
|
|
LogicalResult
|
|
matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (extractOp.getType().isa<VectorType>() ||
|
|
!spirv::CompositeType::isValid(extractOp.getVectorType()))
|
|
return failure();
|
|
vector::ExtractOp::Adaptor adaptor(operands);
|
|
int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
|
|
Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
|
|
extractOp.getLoc(), adaptor.vector(), id);
|
|
rewriter.replaceOp(extractOp, newExtract);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lowers `vector.insert` of a scalar into `spv.CompositeInsert` with a
/// constant index taken from the op's `position` attribute.
/// The pattern fails when the inserted value is a vector, or when the
/// destination vector type is not a valid SPIR-V composite type.
struct VectorInsertOpConvert final
    : public SPIRVOpLowering<vector::InsertOp> {
  using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Inserting a whole sub-vector is not handled here.
    if (insertOp.getSourceType().isa<VectorType>())
      return failure();
    if (!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
      return failure();

    vector::InsertOp::Adaptor adaptor(operands);
    // Only the first position entry is read (see the matching note on the
    // extract pattern: presumably only 1-D vectors reach this point).
    auto positions = insertOp.position();
    int32_t index = positions.begin()->cast<IntegerAttr>().getInt();
    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
        insertOp, adaptor.source(), adaptor.dest(), index);
    return success();
  }
};
|
|
|
|
struct VectorExtractElementOpConvert final
|
|
: public SPIRVOpLowering<vector::ExtractElementOp> {
|
|
using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering;
|
|
LogicalResult
|
|
matchAndRewrite(vector::ExtractElementOp extractElementOp,
|
|
ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
|
|
return failure();
|
|
vector::ExtractElementOp::Adaptor adaptor(operands);
|
|
Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>(
|
|
extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(),
|
|
extractElementOp.position());
|
|
rewriter.replaceOp(extractElementOp, newExtractElement);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct VectorInsertElementOpConvert final
|
|
: public SPIRVOpLowering<vector::InsertElementOp> {
|
|
using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering;
|
|
LogicalResult
|
|
matchAndRewrite(vector::InsertElementOp insertElementOp,
|
|
ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
|
|
return failure();
|
|
vector::InsertElementOp::Adaptor adaptor(operands);
|
|
Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>(
|
|
insertElementOp.getLoc(), insertElementOp.getType(),
|
|
insertElementOp.dest(), adaptor.source(), insertElementOp.position());
|
|
rewriter.replaceOp(insertElementOp, newInsertElement);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers every Vector -> SPIR-V conversion pattern defined in this file
/// into `patterns`. Each pattern is constructed with `context` and
/// `typeConverter`.
void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                         SPIRVTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
  // Broadcast first, then the static- and dynamic-index element accessors,
  // preserving the original registration order.
  patterns.insert<VectorBroadcastConvert>(context, typeConverter);
  patterns.insert<VectorExtractOpConvert, VectorInsertOpConvert,
                  VectorExtractElementOpConvert, VectorInsertElementOpConvert>(
      context, typeConverter);
}
|
|
|
|
namespace {
|
|
struct LowerVectorToSPIRVPass
|
|
: public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void LowerVectorToSPIRVPass::runOnOperation() {
|
|
MLIRContext *context = &getContext();
|
|
ModuleOp module = getOperation();
|
|
|
|
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
|
|
std::unique_ptr<ConversionTarget> target =
|
|
spirv::SPIRVConversionTarget::get(targetAttr);
|
|
|
|
SPIRVTypeConverter typeConverter(targetAttr);
|
|
OwningRewritePatternList patterns;
|
|
populateVectorToSPIRVPatterns(context, typeConverter, patterns);
|
|
|
|
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
|
target->addLegalOp<FuncOp>();
|
|
|
|
if (failed(applyFullConversion(module, *target, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
|
|
/// Creates a new instance of the Vector -> SPIR-V lowering pass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToSPIRVPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<LowerVectorToSPIRVPass>();
  return pass;
}
|