clang-p2996/mlir/test/CAPI/sparse_tensor.c
wren romano 933fefb6a8 [mlir][sparse] Adjusting DimLevelType numeric values for faster predicates
This differential adjusts the numeric values of the DimLevelType enum: the low-order two bits record the "No" (not-ordered) and "Nu" (not-unique) properties, and the high-order bits record the format per se.  (The choice of encoding may seem a bit peculiar, since the bits are mapped to negative properties rather than positive ones.  This was done in order to preserve the collation order of DimLevelType values.  If we did not care about collation order, we might prefer to flip the semantics of the property bits so that they are less surprising to readers.)

Using distinguished bits for the properties and formats enables a faster implementation of the predicates that detect those properties/formats, which matters because this code is in the runtime library itself (rather than on the codegen side of things).  This differential pushes through the changes to the enum values and optimizes the basic predicates.  However, it does not optimize all the places where compound predicates are checked (e.g., "is compressed or singleton"), in order to reduce rebasing conflicts with D134933.  Those optimizations will be done after this differential and D134933 have landed.
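
A rough sketch of the kind of layout and predicates this describes, assuming bit 0 records "Nu" and bit 1 records "No" (every name below is invented for illustration; apart from dense = 4 and compressed = 8, which the test's CHECK lines echo, the values and exact bit assignments are assumptions):

  #include <stdbool.h>
  #include <stdio.h>

  /* Hypothetical mirror of the layout described above:
   *   bit 0 = "Nu" (not unique), bit 1 = "No" (not ordered),
   *   higher bits = the format itself. */
  enum DemoDimLevelType {
    DEMO_DENSE            = 4,  /* 0b001_00 */
    DEMO_COMPRESSED       = 8,  /* 0b010_00 */
    DEMO_COMPRESSED_NU    = 9,  /* 0b010_01 */
    DEMO_COMPRESSED_NO    = 10, /* 0b010_10 */
    DEMO_COMPRESSED_NU_NO = 11, /* 0b010_11 */
    DEMO_SINGLETON        = 16  /* 0b100_00 */
  };

  /* With dedicated property bits, each predicate is a single
   * mask-and-compare instead of a switch over every enum value. */
  static inline bool demoIsOrdered(enum DemoDimLevelType dlt) {
    return !(dlt & 2); /* "No" bit clear => ordered */
  }
  static inline bool demoIsUnique(enum DemoDimLevelType dlt) {
    return !(dlt & 1); /* "Nu" bit clear => unique */
  }
  static inline bool demoIsCompressed(enum DemoDimLevelType dlt) {
    return (dlt & ~3) == DEMO_COMPRESSED; /* ignore the property bits */
  }

  int main(void) {
    /* A not-unique, not-ordered compressed level still satisfies the
     * format predicate while both property predicates flip. */
    printf("%d %d %d\n", demoIsCompressed(DEMO_COMPRESSED_NU_NO),
           demoIsOrdered(DEMO_COMPRESSED_NU_NO),
           demoIsUnique(DEMO_COMPRESSED_NU_NO)); /* prints: 1 0 0 */
    return 0;
  }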

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135004
2022-10-05 17:40:38 -07:00

//===- sparse_tensor.c - Test of sparse_tensor APIs -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: mlir-capi-sparse-tensor-test 2>&1 | FileCheck %s

#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
#include "mlir-c/RegisterEverything.h"

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// CHECK-LABEL: testRoundtripEncoding()
static int testRoundtripEncoding(MlirContext ctx) {
  fprintf(stderr, "testRoundtripEncoding()\n");
  // clang-format off
  const char *originalAsm =
      "#sparse_tensor.encoding<{ "
      "dimLevelType = [ \"dense\", \"compressed\", \"compressed\"], "
      "dimOrdering = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, "
      "higherOrdering = affine_map<(d0, d1)[s0] -> (s0, d0, d1)>, "
      "pointerBitWidth = 32, indexBitWidth = 64 }>";
  // clang-format on
  MlirAttribute originalAttr =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(originalAsm));
  // CHECK: isa: 1
  fprintf(stderr, "isa: %d\n",
          mlirAttributeIsASparseTensorEncodingAttr(originalAttr));

  MlirAffineMap dimOrdering =
      mlirSparseTensorEncodingAttrGetDimOrdering(originalAttr);
  // CHECK: (d0, d1, d2) -> (d0, d1, d2)
  mlirAffineMapDump(dimOrdering);

  MlirAffineMap higherOrdering =
      mlirSparseTensorEncodingAttrGetHigherOrdering(originalAttr);
  // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
  mlirAffineMapDump(higherOrdering);

  // CHECK: level_type: 4
  // CHECK: level_type: 8
  // CHECK: level_type: 8
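  // Note: the expected values above reflect the adjusted numeric encoding
  // described in this change ("dense" prints as 4, "compressed" as 8).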
  int numLevelTypes =
      mlirSparseTensorEncodingGetNumDimLevelTypes(originalAttr);
  enum MlirSparseTensorDimLevelType *levelTypes =
      malloc(sizeof(enum MlirSparseTensorDimLevelType) * numLevelTypes);
  for (int i = 0; i < numLevelTypes; ++i) {
    levelTypes[i] =
        mlirSparseTensorEncodingAttrGetDimLevelType(originalAttr, i);
    fprintf(stderr, "level_type: %d\n", levelTypes[i]);
  }

  // CHECK: pointer: 32
  int pointerBitWidth =
      mlirSparseTensorEncodingAttrGetPointerBitWidth(originalAttr);
  fprintf(stderr, "pointer: %d\n", pointerBitWidth);

  // CHECK: index: 64
  int indexBitWidth =
      mlirSparseTensorEncodingAttrGetIndexBitWidth(originalAttr);
  fprintf(stderr, "index: %d\n", indexBitWidth);

  MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
      ctx, numLevelTypes, levelTypes, dimOrdering, higherOrdering,
      pointerBitWidth, indexBitWidth);
  mlirAttributeDump(newAttr); // For debugging filecheck output.

  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
  free(levelTypes);
  return 0;
}

int main() {
  MlirContext ctx = mlirContextCreate();
  mlirDialectHandleRegisterDialect(mlirGetDialectHandle__sparse_tensor__(),
                                   ctx);
  if (testRoundtripEncoding(ctx))
    return 1;

  mlirContextDestroy(ctx);
  return 0;
}