Files
clang-p2996/mlir/test/python/dialects/transform_vector_ext.py
Alex Zinenko d517117303 [mlir] python bindings for vector transform ops
Provide Python bindings for transform ops defined in the vector dialect.
All of these ops are sufficiently simple that no mixins are necessary
for them to be nicely usable.

Reviewed By: ingomueller-net

Differential Revision: https://reviews.llvm.org/D156554
2023-07-31 15:42:59 +00:00

154 lines
6.5 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import vector
def run_apply_patterns(f):
    """Decorator that runs ``f`` inside a transform ``apply_patterns`` region.

    A new module is created holding a transform sequence op; an apply-patterns
    op is built on the sequence's body target and ``f`` is invoked with the
    insertion point set to that op's ``patterns`` block. Afterwards the test
    name and the resulting module are printed to stdout for FileCheck.

    Args:
        f: zero-argument callable that builds pattern ops at the current
            insertion point.

    Returns:
        ``f`` itself, unchanged, so the decorated name stays bound to it.
    """
    with Context():
        with Location.unknown():
            module = Module.create()
            with InsertionPoint(module.body):
                failure_mode = transform.FailurePropagationMode.PROPAGATE
                root_type = transform.AnyOpType.get()
                seq_op = transform.SequenceOp(failure_mode, [], root_type)
                with InsertionPoint(seq_op.body):
                    patterns_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                    with InsertionPoint(patterns_op.patterns):
                        # The decorated test emits its pattern ops here.
                        f()
                    # Emitted after the apply-patterns op, still inside the
                    # sequence body.
                    transform.YieldOp()
            # The leading newline separates consecutive tests in the output.
            print("\nTEST:", f.__name__)
            print(module)
    return f
@run_apply_patterns
def non_configurable_patterns():
    """Build each vector pattern op that is constructed without arguments.

    The ops are listed once, in the order FileCheck expects to see them in the
    printed module, and then instantiated in that same order.
    """
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    pattern_op_classes = (
        # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
        vector.ApplyCastAwayVectorLeadingOneDimPatternsOp,
        # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
        vector.ApplyRankReducingSubviewPatternsOp,
        # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
        vector.ApplyTransferPermutationPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_broadcast
        vector.ApplyLowerBroadcastPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masks
        vector.ApplyLowerMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masked_transfers
        vector.ApplyLowerMaskedTransfersPatternsOp,
        # CHECK: transform.apply_patterns.vector.materialize_masks
        vector.ApplyMaterializeMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_outerproduct
        vector.ApplyLowerOuterProductPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_gather
        vector.ApplyLowerGatherPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_scan
        vector.ApplyLowerScanPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_shape_cast
        vector.ApplyLowerShapeCastPatternsOp,
    )
    # Tuple order is emission order; it must stay in step with the directives.
    for build_op in pattern_op_classes:
        build_op()
@run_apply_patterns
def configurable_patterns():
    """Build the vector pattern ops that take integer and boolean options.

    Each op's options are collected in a small dict first so the values being
    checked by FileCheck sit right next to the directives that expect them.
    """
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_transfer
    # CHECK-SAME: max_transfer_rank = 4
    lower_transfer_options = {"max_transfer_rank": 4}
    vector.ApplyLowerTransferPatternsOp(**lower_transfer_options)
    # CHECK: transform.apply_patterns.vector.transfer_to_scf
    # CHECK-SAME: max_transfer_rank = 3
    # CHECK-SAME: full_unroll = true
    transfer_to_scf_options = {"max_transfer_rank": 3, "full_unroll": True}
    vector.ApplyTransferToScfPatternsOp(**transfer_to_scf_options)
@run_apply_patterns
def enum_configurable_patterns():
    """Build the vector pattern ops whose options are enum attributes.

    Covers contraction, multi-reduction, transpose and split-transfer pattern
    ops, each constructed once with no arguments and then with explicit enum
    values. The expected textual form of every op is given by the FileCheck
    directive comments placed directly above the call that produces it.

    The label directive at the top ties the checks below to this test's own
    section of the output (the "TEST: <name>" line printed by the decorator),
    matching the other tests in this file.
    """
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = matmulintrinsics
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.MATMUL
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.PARALLEL_ARITH
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    vector.ApplyLowerMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # This is the default mode, not printed.
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.INNER_PARALLEL
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.INNER_REDUCTION
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = eltwise
    # CHECK-SAME: avx2_lowering_strategy = false
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = eltwise
    # CHECK-SAME: avx2_lowering_strategy = false
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.ELT_WISE
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    # CHECK-SAME: avx2_lowering_strategy = false
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.FLAT
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    # CHECK-SAME: avx2_lowering_strategy = false
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.SHUFFLE1_D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    # CHECK-SAME: avx2_lowering_strategy = false
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.SHUFFLE16X16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.FLAT,
        avx2_lowering_strategy=True,
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.NONE
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VECTOR_TRANSFER
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LINALG_COPY
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.FORCE_IN_BOUNDS
    )