//===- Utils.cpp - Transform utilities ------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/NVGPU/Transforms/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" using namespace mlir; using namespace mlir::nvgpu; Operation::operand_range nvgpu::getIndices(Operation *op) { if (auto ldmatrixOp = dyn_cast(op)) return ldmatrixOp.getIndices(); if (auto copyOp = dyn_cast(op)) return copyOp.getDstIndices(); if (auto loadOp = dyn_cast(op)) return loadOp.getIndices(); if (auto storeOp = dyn_cast(op)) return storeOp.getIndices(); if (auto vectorReadOp = dyn_cast(op)) return vectorReadOp.getIndices(); if (auto vectorStoreOp = dyn_cast(op)) return vectorStoreOp.getIndices(); llvm_unreachable("unsupported op type"); } void nvgpu::setIndices(Operation *op, ArrayRef indices) { if (auto ldmatrixOp = dyn_cast(op)) return ldmatrixOp.getIndicesMutable().assign(indices); if (auto copyOp = dyn_cast(op)) return copyOp.getDstIndicesMutable().assign(indices); if (auto loadOp = dyn_cast(op)) return loadOp.getIndicesMutable().assign(indices); if (auto storeOp = dyn_cast(op)) return storeOp.getIndicesMutable().assign(indices); if (auto vectorReadOp = dyn_cast(op)) return vectorReadOp.getIndicesMutable().assign(indices); if (auto vectorStoreOp = dyn_cast(op)) return vectorStoreOp.getIndicesMutable().assign(indices); llvm_unreachable("unsupported op type"); }