[MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (#144447)

Add support for load/store with chunk_size, which requires special
consideration for operand blocking, since the offsets and masks are
n-D while the tensor descriptor is (n+1)-D. Supported operations include
create_tdesc, update_offset, load, store, and prefetch.
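
For illustration, a minimal sketch of a chunked gather load and how it is expected to be blocked with inst_data = [16, 2]; it mirrors the load_chunk test added in this patch, and the function and value names are illustrative only:

gpu.module @example {
  gpu.func @load_chunk_sketch(%src: ui64, %offsets: vector<32xindex>, %mask: vector<32xi1>) -> vector<4x32xf32> {
    // chunk_size = 4 over 32 lanes; unrolling produces 4 create_tdesc/load pairs of type
    // !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, each
    // reusing a vector<16xi1> half of %mask and yielding a vector<2x16xf32> piece.
    %tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
    %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
    gpu.return %ld : vector<4x32xf32>
  }
}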

---------

Co-authored-by: Adam Siemieniuk <adam.siemieniuk@intel.com>
Author: Jianhui Li
Date: 2025-06-17 15:46:35 -07:00 (committed by GitHub)
Parent: 3f33c8482f
Commit: f25f2f7de4
3 changed files with 315 additions and 124 deletions


@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
VectorType indiceVecTy = indiceVec.getType();
// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
SmallVector<int64_t> targetIndiceShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();
// indiceVec is one dimension lower than tdescTy when chunkSize is larger than 1.
if (originalChunkSize > 1)
targetIndiceShape.pop_back();
auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
VectorType indiceVecTy = indiceVec.getType();
SmallVector<Type> convertedIndiceTypes =
getUnrolledTypes(indiceVecTy, *targetShape);
getUnrolledTypes(indiceVecTy, targetIndiceShape);
SmallVector<Value> convertedIndiceVec =
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
SmallVector<Value> newOps;
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
op.getSource(), indice);
newOps.push_back(newOp);
// More indices are needed when chunkSize > 1, since a big load from one
// address could be broken into multiple small loads.
if (originalChunkSize > 1) {
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
for (auto [indice, indiceType] :
llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
for (int64_t i = 0; i < numNewChunks; ++i) {
// Compute the offset
Value inc = rewriter.create<arith::ConstantIndexOp>(
loc, i * blockedChunkSize);
Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
Value offsetIndice =
rewriter.create<arith::AddIOp>(loc, indice, incVec);
auto newOp = rewriter.create<xegpu::CreateDescOp>(
loc, newTdescTy, op.getSource(), offsetIndice);
newOps.push_back(newOp);
}
}
} else {
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(
loc, newTdescTy, op.getSource(), indice);
newOps.push_back(newOp);
}
}
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
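
As a sketch of the splitting above (modeled on the create_tdesc_step_chunk3 test added below; names are illustrative): with chunk_size = 8 and inst_data = [16, 2], each index vector is reused 8 / 2 = 4 times, with the per-lane offsets incremented by 0, 2, 4 and 6 elements.

gpu.module @example {
  gpu.func @create_tdesc_chunk_sketch(%src: ui64, %offsets: vector<16xindex>) -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>> {
    // Unrolls into 4 xegpu.create_tdesc ops producing
    // !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>,
    // each preceded by an arith.addi that shifts %offsets by i * 2 elements.
    %tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
    gpu.return %tdesc : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
  }
}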
@@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
SmallVector<int64_t> targetMaskShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
@@ -462,10 +492,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
SmallVector<Value> convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
if (originalChunkSize > 1) {
targetMaskShape.pop_back();
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
SmallVector<Value> convertedMasks1D = pack(
op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
for (auto mask : convertedMasks1D) {
for (int64_t i = 0; i < numNewChunks; ++i)
convertedMasks.push_back(mask);
}
// This is to handle the transpose effect when chunkSize > 1.
std::swap((*targetShape)[0], (*targetShape)[1]);
newValueTy = valueTy.cloneWith(*targetShape, elemTy);
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
loc, rewriter);
}
SmallVector<Value> newOps;
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -476,7 +525,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
}
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
return success();
}
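
The shape swap above reflects that each blocked load returns its value with the chunk dimension leading. A minimal sketch of one blocked piece, matching the per-block type used in the tests (argument names are illustrative):

gpu.module @example {
  gpu.func @blocked_load_piece(%blk_tdesc: !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, %blk_mask: vector<16xi1>) -> vector<2x16xf32> {
    // The mask stays 1-D per 16 lanes, while the loaded value is <2x16xf32>,
    // i.e. the chunk dimension comes first; hence the swap of targetShape dims.
    %piece = xegpu.load %blk_tdesc, %blk_mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
    gpu.return %piece : vector<2x16xf32>
  }
}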
@@ -489,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -519,30 +566,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<int64_t> targetIndiceShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
SmallVector<Value> convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
if (originalChunkSize > 1) {
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
SmallVector<Value> convertedMasks1D = pack(
op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
for (auto mask : convertedMasks1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
convertedMasks.push_back(mask);
}
}
// This is to handle the transpose effect when chunkSize > 1.
std::swap((*targetShape)[0], (*targetShape)[1]);
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
}
SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
for (size_t i = 0; i < convertedValues.size(); ++i) {
Value v = convertedValues[i];
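
The store path mirrors the load path: values are blocked as <2x16xf32> pieces with the chunk dimension leading, and each vector<16xi1> mask half is reused for both chunk blocks. A minimal sketch of one blocked store (argument names are illustrative):

gpu.module @example {
  gpu.func @blocked_store_piece(%v: vector<2x16xf32>, %blk_tdesc: !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, %blk_mask: vector<16xi1>) {
    // One of the 4 blocked stores produced from a chunk_size = 4, 32-lane store
    // with inst_data = [16, 2].
    xegpu.store %v, %blk_tdesc, %blk_mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
    gpu.return
  }
}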
@@ -565,8 +633,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (tdescTy.getRank() > 2)
return failure();
if (!tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -580,12 +650,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
VectorType offsetVecTy = offsetVec.getType();
SmallVector<Type> convertedOffsetTypes =
getUnrolledTypes(offsetVecTy, *targetShape);
SmallVector<Value> convertedOffsetVec =
pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedOffsetTypes;
SmallVector<Value> convertedOffsetVec;
SmallVector<Value> newOps;
int64_t originalChunkSize = tdescTy.getChunkSize();
if (originalChunkSize > 1) {
SmallVector<int64_t> shape1D(targetShape->begin(),
targetShape->end() - 1);
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
SmallVector<Value> convertedOffsetVec1D =
pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
for (auto offset : convertedOffsetVec1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
convertedOffsetVec.push_back(offset);
}
}
} else {
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
convertedOffsetVec =
pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
}
for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
auto newOp =
rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
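
A sketch of the corresponding offset handling (modeled on the update_chunk test added below; names are illustrative): with chunk_size = 4 and inst_data = [16, 2], the 32-lane delta vector is split into two vector<16xindex> halves and each half is reused for both chunk blocks, yielding 4 xegpu.update_offset ops.

gpu.module @example {
  gpu.func @update_offset_chunk_sketch(%tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, %delta: vector<32xindex>) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
    // Unrolls into 4 xegpu.update_offset ops on <16x2> descriptors, each taking
    // one of the two vector<16xindex> halves of %delta.
    %new = xegpu.update_offset %tdesc, %delta : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
    gpu.return %new : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
  }
}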


@@ -2,7 +2,7 @@
gpu.module @test {
// CHECK-LABEL: test_create_nd_tdesc
// CHECK-LABEL: create_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
@@ -10,31 +10,31 @@ gpu.module @test {
// CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
// CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
// CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
}
//-----
// CHECK-LABEL: test_create_nd_tdesc_1d
// CHECK-LABEL: create_nd_tdesc_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
// CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
// CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {__xegpu_blocking_tile_shape__ = array<i64: 16>, __xegpu_blocking_unpack__}
gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
gpu.func @create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
}
//-----
// CHECK-LABEL: test_update_nd_tdesc
// CHECK-LABEL: update_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32>
gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -42,11 +42,11 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_update_nd_tdesc_1d
// CHECK-LABEL: update_nd_tdesc_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32>
gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
gpu.func @update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
%update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
@@ -54,11 +54,11 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_prefetch_nd_tdesc
// CHECK-LABEL: prefetch_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32>
gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return
@@ -66,23 +66,23 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_prefetch_nd_tdesc_1d
// CHECK-LABEL: prefetch_nd_tdesc_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32>
gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
gpu.func @prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
gpu.return
}
//-----
// CHECK-LABEL: test_load_nd
// CHECK-LABEL: load_nd
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
// CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
gpu.func @load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
gpu.return %ld : vector<24x32xf32>
@@ -90,12 +90,12 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_load_nd_1d
// CHECK-LABEL: load_nd_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
// CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32>
gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
gpu.func @load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
%data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>> -> vector<64xf32>
gpu.return %data : vector<64xf32>
@@ -103,11 +103,11 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_store_nd
// CHECK-LABEL: store_nd
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.func @test_store_nd(%src: memref<24x32xf32>) {
gpu.func @store_nd(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%data = arith.constant dense<9.0> : vector<24x32xf32>
xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -116,11 +116,11 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_store_nd_1d
// CHECK-LABEL: store_nd_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32>
gpu.func @test_store_nd_1d(%src: memref<64xf32>) {
gpu.func @store_nd_1d(%src: memref<64xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
%data = arith.constant dense<9.0> : vector<64xf32>
xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
@@ -129,7 +129,7 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_createNd_loadNd_storeNd
// CHECK-LABEL: createNd_loadNd_storeNd
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
//CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
//CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
@@ -137,7 +137,7 @@ gpu.module @test {
//CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
//CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
//CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
gpu.func @createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%data = arith.constant dense<9.0> : vector<24x32xf32>
%ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
@@ -148,23 +148,23 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_dpas
// CHECK-LABEL: dpas
// CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16>
//CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16>
//CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16>
//CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32>
//CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
gpu.func @dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
%c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
gpu.return %c : vector<32x32xf32>
}
//-----
// CHECK-LABEL: test_create_tdesc_vec
// CHECK-LABEL: create_tdesc_vec
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
gpu.func @create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -177,10 +177,10 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_create_tdesc_step
// CHECK-LABEL: create_tdesc_step
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
gpu.func @create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
%step = arith.constant dense<8> : vector<32xindex>
%seq = vector.step : vector<32xindex>
%cst = arith.muli %seq, %step : vector<32xindex>
@@ -190,11 +190,11 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_load
// CHECK-LABEL: load
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
gpu.func @test_load(%src: ui64) -> vector<32xf32> {
gpu.func @load(%src: ui64) -> vector<32xf32> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -212,11 +212,11 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_prefetch
// CHECK-LABEL: prefetch
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_prefetch(%src: ui64) {
gpu.func @prefetch(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
@@ -233,11 +233,11 @@ gpu.module @test {
//-----
// CHECK-LABEL: test_store
// CHECK-LABEL: store
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
gpu.func @test_store(%src: ui64) {
gpu.func @store(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -256,47 +256,129 @@ gpu.module @test {
}
//-----
// CHECK-LABEL: test_prefetch_load_store_update
// CHECK-LABEL: create_tdesc_step_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
// CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>>
gpu.func @create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
%step = arith.constant dense<8> : vector<32xindex>
%seq = vector.step : vector<32xindex>
%cst = arith.muli %seq, %step : vector<32xindex>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
}
gpu.func @test_prefetch_load_store_update(%src: ui64) {
//-----
// CHECK-LABEL: create_tdesc_step_chunk2
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
gpu.func @create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
%step = arith.constant dense<8> : vector<32xindex>
%seq = vector.step : vector<32xindex>
%cst = arith.muli %seq, %step : vector<32xindex>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
}
// CHECK-LABEL: create_tdesc_step_chunk3
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
// CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
// CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
// CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
gpu.func @create_tdesc_step_chunk3(%src: ui64) -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>> {
%step = arith.constant dense<8> : vector<16xindex>
%seq = vector.step : vector<16xindex>
%cst = arith.muli %seq, %step : vector<16xindex>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
gpu.return %tdesc : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
}
//-----
// CHECK-LABEL: load_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.load {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
gpu.func @load_chunk(%src: ui64) -> vector<4x32xf32> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
%delta = arith.constant dense<[
32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 32, 32, 32, 32, 64,
128, 128, 128, 128, 128, 128, 128, 128,
128, 128, 128, 128, 128, 128, 128, 256
]> : vector<32xindex>
%new_tdesc = xegpu.update_offset %tdesc, %delta
: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>
%ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
%ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
gpu.return %ld : vector<4x32xf32>
}
%st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
xegpu.store %st_vec, %tdesc, %mask:
vector<32xf32>,
!xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
vector<32xi1>
//-----
// CHECK-LABEL: store_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.store {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
gpu.func @store_chunk(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>
%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>
%st_vec = arith.constant dense<1023.>: vector<4x32xf32>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: vector<4x32xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16,2]>>, vector<32xi1>
gpu.return
}
//-----
// CHECK-LABEL: prefetch_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
gpu.func @prefetch_chunk(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
gpu.return
}
//-----
// CHECK-LABEL: update_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
gpu.func @update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>
%delta = arith.constant dense<32>: vector<32xindex>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
%new_tdesc = xegpu.update_offset %tdesc, %delta
: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
gpu.return %new_tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
}
}


@@ -19,6 +19,10 @@ using namespace mlir::xegpu;
namespace {
#define DEBUG_TYPE "test-xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
struct TestXeGPUUnrollingPatterns
: public PassWrapper<TestXeGPUUnrollingPatterns,
OperationPass<gpu::GPUModuleOp>> {
@@ -48,7 +52,9 @@ struct TestXeGPUUnrollingPatterns
options.setNativeShapeFn(
[&](Operation *op) -> std::optional<SmallVector<int64_t>> {
if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
xegpu::TensorDescType tdescTy;
if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
tdescTy = createNdOp.getType();
@@ -61,20 +67,7 @@ struct TestXeGPUUnrollingPatterns
tdescTy = loadNdOp.getTensorDescType();
} else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
tdescTy = storeNdOp.getTensorDescType();
}
if (auto layout = tdescTy.getLayoutAttr()) {
auto inst_data = layout.getInstData();
if (inst_data && layout.isSgLayout())
return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
inst_data.asArrayRef().end());
}
}
if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
xegpu::TensorDescType tdescTy;
if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
} else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
tdescTy = createOp.getType();
} else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
tdescTy = updateOp.getTensorDescType();
@@ -111,14 +104,40 @@ struct TestXeGPUUnrollingPatterns
Attribute encoding = tdescTy.getEncoding();
auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
tdescTy.getLayout());
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
auto scatterAttr =
mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
auto instData = layout.getInstData();
if (!instData.empty())
blockedChunkSize = instData.asArrayRef().back();
auto chunkSizeAttr = mlir::IntegerAttr::get(
mlir::IntegerType::get(ctx, 64), blockedChunkSize);
// Create a new scatter attribute with the blocked chunk_size.
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
encoding = newEncoding;
}
}
if (layout) {
if (layout.getLaneLayout() == nullptr)
layout = xegpu::LayoutAttr();
else
layout = layout.dropInstData();
}
newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
layout);
} else {
newTy = type.clone(tileShape, elemTy);
}
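
To make the chunk-size adjustment above concrete, a minimal sketch of the original and blocked descriptor types for inst_data = [16, 2] (function names are illustrative):

gpu.module @example {
  // Original type: chunk_size = 4 with inst_data = [16, 2].
  gpu.func @original_type(%td: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>) {
    gpu.return
  }
  // Blocked tile type: chunk_size becomes 2 (the last inst_data entry) and the
  // inst_data-only layout is dropped.
  gpu.func @blocked_type(%td: !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>) {
    gpu.return
  }
}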