[mlir] Expose simplifyAffineExpr through python api (#133926)
This commit is contained in:
@@ -104,6 +104,16 @@ MLIR_CAPI_EXPORTED MlirAffineExpr
|
||||
mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols,
|
||||
uint32_t shift, uint32_t offset);
|
||||
|
||||
/// Simplify an affine expression by flattening and some amount of simple
|
||||
/// analysis. This has complexity linear in the number of nodes in 'expr'.
|
||||
/// Returns the simplified expression, which is the same as the input expression
|
||||
/// if it can't be simplified. When `expr` is semi-affine, a simplified
|
||||
/// semi-affine expression is constructed in the sorted order of dimension and
|
||||
/// symbol positions.
|
||||
MLIR_CAPI_EXPORTED MlirAffineExpr mlirSimplifyAffineExpr(MlirAffineExpr expr,
|
||||
uint32_t numDims,
|
||||
uint32_t numSymbols);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Affine Dimension Expression.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -599,6 +599,16 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
|
||||
},
|
||||
nb::arg("num_symbols"), nb::arg("shift"),
|
||||
nb::arg("offset").none() = 0)
|
||||
.def_static(
|
||||
"simplify_affine_expr",
|
||||
[](PyAffineExpr &self, uint32_t numDims, uint32_t numSymbols) {
|
||||
return PyAffineExpr(
|
||||
self.getContext(),
|
||||
mlirSimplifyAffineExpr(self, numDims, numSymbols));
|
||||
},
|
||||
nb::arg("expr"), nb::arg("num_dims"), nb::arg("num_symbols"),
|
||||
"Simplify an affine expression by flattening and some amount of "
|
||||
"simple analysis.")
|
||||
.def_static(
|
||||
"get_add", &PyAffineAddExpr::get,
|
||||
"Gets an affine expression containing a sum of two expressions.")
|
||||
|
||||
@@ -73,6 +73,11 @@ MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr,
|
||||
return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset));
|
||||
}
|
||||
|
||||
MlirAffineExpr mlirSimplifyAffineExpr(MlirAffineExpr expr, uint32_t numDims,
|
||||
uint32_t numSymbols) {
|
||||
return wrap(simplifyAffineExpr(unwrap(expr), numDims, numSymbols));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Affine Dimension Expression.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -416,3 +416,11 @@ def testAffineExprShift():
|
||||
|
||||
assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
|
||||
assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineExprSimplify
|
||||
@run
|
||||
def testAffineExprSimplify():
|
||||
with Context() as ctx:
|
||||
expr = AffineExpr.get_dim(0) + AffineExpr.get_symbol(0)
|
||||
assert expr == AffineExpr.simplify_affine_expr(expr, 1, 1)
|
||||
|
||||
Reference in New Issue
Block a user