[MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (#144447)
Add support for load/store with chunk_size, which requires special consideration for operand blocking, since offsets and masks are n-D while the tensor descriptor is (n+1)-D. Supported operations include create_tdesc, update_offset, load, store, and prefetch.

---------

Co-authored-by: Adam Siemieniuk <adam.siemieniuk@intel.com>
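For orientation, here is a small standalone sketch of the blocking arithmetic the patterns below rely on. It is not part of the patch; the concrete shapes and the helper names are illustrative assumptions borrowed from the tests (a scattered tensor_desc<32x4xf32> with chunk_size = 4 and inst_data = [16, 2]).

// Standalone sketch, not part of the patch: it only mirrors the blocking
// bookkeeping described above. The concrete shapes are assumptions taken
// from the tests: tensor_desc<32x4xf32, chunk_size = 4>, inst_data = [16, 2].
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t tdescShape[2] = {32, 4}; // 32 addresses x chunk of 4 elements
  const int64_t instData[2] = {16, 2};   // per-instruction tile

  const int64_t originalChunkSize = tdescShape[1];
  const int64_t blockedChunkSize = instData[1];
  const int64_t numNewChunks = originalChunkSize / blockedChunkSize; // 2
  const int64_t numAddrBlocks = tdescShape[0] / instData[0];         // 2

  // Offsets and masks are 1-D (one lane per address), so they are blocked
  // along the address dimension only; each block is then reused for every
  // sub-chunk, and create_tdesc additionally bumps the reused offsets by
  // i * blockedChunkSize so descriptor i points at its own sub-chunk.
  for (int64_t blk = 0; blk < numAddrBlocks; ++blk)
    for (int64_t i = 0; i < numNewChunks; ++i)
      std::printf("descriptor(addr block %lld, sub-chunk %lld): "
                  "reuse offsets/mask of block %lld, offsets += %lld\n",
                  static_cast<long long>(blk), static_cast<long long>(i),
                  static_cast<long long>(blk),
                  static_cast<long long>(i * blockedChunkSize));

  // 2 address blocks x 2 sub-chunks = 4 new descriptors, which matches the
  // CHECK-COUNT-4 lines in the chunked tests below.
  return 0;
}

For load and store the same counting applies; in addition the blocked value tile is transposed (a 16x2 descriptor tile yields a vector<2x16xf32> value), which is why the load/store patterns below swap the two entries of targetShape when chunk_size > 1.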
@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    // check if the tensor descriptor type is a 1d vector type
    if (tdescTy.getRank() > 1)
    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSize();
    // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, *targetShape);
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto indice : convertedIndiceVec) {
      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
                                                        op.getSource(), indice);
      newOps.push_back(newOp);

    // More indices is need when chunkSize > 1. Since a big load from one
    // address could be break into multiple small loads.
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the offset
          Value inc = rewriter.create<arith::ConstantIndexOp>(
              loc, i * blockedChunkSize);
          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
          Value offsetIndice =
              rewriter.create<arith::AddIOp>(loc, indice, incVec);

          auto newOp = rewriter.create<xegpu::CreateDescOp>(
              loc, newTdescTy, op.getSource(), offsetIndice);

          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = rewriter.create<xegpu::CreateDescOp>(
            loc, newTdescTy, op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
@@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // check if the tensor descriptor type is a 1d vector type
    if (tdescTy.getRank() > 1)
    if (!tdescTy.isScattered())
      return failure();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSize();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

@@ -462,10 +492,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes =
        getUnrolledTypes(maskTy, *targetShape);
    SmallVector<Value> convertedMasks =
        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      SmallVector<Value> convertedMasks1D = pack(
          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto mask : convertedMasks1D) {
        for (int64_t i = 0; i < numNewChunks; ++i)
          convertedMasks.push_back(mask);
      }
      // This is to handle the transpose effect when chunkSize > 1.
      std::swap((*targetShape)[0], (*targetShape)[1]);
      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -476,7 +525,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
@@ -489,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // check if the tensor descriptor type is a 1d vector type
    if (tdescTy.getRank() > 1)
    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -519,30 +566,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // check if the tensor descriptor type is a 1d vector type
    if (tdescTy.getRank() > 1)
    if (!tdescTy.isScattered())
      return failure();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSize();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes =
        getUnrolledTypes(maskTy, *targetShape);
    SmallVector<Value> convertedMasks =
        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
      SmallVector<Value> convertedMasks1D = pack(
          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);

      for (auto mask : convertedMasks1D) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          convertedMasks.push_back(mask);
        }
      }
      // This is to handle the transpose effect when chunkSize > 1.
      std::swap((*targetShape)[0], (*targetShape)[1]);

    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
      convertedMasks =
          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
@@ -565,8 +633,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // check if the tensor descriptor type is a 1d vector type
    if (tdescTy.getRank() > 1)
    if (tdescTy.getRank() > 2)
      return failure();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -580,12 +650,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {

    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes =
        getUnrolledTypes(offsetVecTy, *targetShape);
    SmallVector<Value> convertedOffsetVec =
        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    SmallVector<Value> newOps;
    int64_t originalChunkSize = tdescTy.getChunkSize();
    if (originalChunkSize > 1) {
      SmallVector<int64_t> shape1D(targetShape->begin(),
                                   targetShape->end() - 1);
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
      SmallVector<Value> convertedOffsetVec1D =
          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto offset : convertedOffsetVec1D) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          convertedOffsetVec.push_back(offset);
        }
      }

    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);

@@ -2,7 +2,7 @@

gpu.module @test {

  // CHECK-LABEL: test_create_nd_tdesc
  // CHECK-LABEL: create_nd_tdesc
  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
@@ -10,31 +10,31 @@ gpu.module @test {
  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
  gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
  }

  //-----

  // CHECK-LABEL: test_create_nd_tdesc_1d
  // CHECK-LABEL: create_nd_tdesc_1d
  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
  // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
  // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
  // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
  // CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {__xegpu_blocking_tile_shape__ = array<i64: 16>, __xegpu_blocking_unpack__}
  gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
  gpu.func @create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
    gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
  }

  //-----

  // CHECK-LABEL: test_update_nd_tdesc
  // CHECK-LABEL: update_nd_tdesc
  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  // CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32>
  gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
  gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -42,11 +42,11 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_update_nd_tdesc_1d
  // CHECK-LABEL: update_nd_tdesc_1d
  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
  // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
  // CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32>
  gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
  gpu.func @update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
    %update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
    gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
@@ -54,11 +54,11 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_prefetch_nd_tdesc
  // CHECK-LABEL: prefetch_nd_tdesc
  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32>
  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
  gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    gpu.return
@@ -66,23 +66,23 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_prefetch_nd_tdesc_1d
  // CHECK-LABEL: prefetch_nd_tdesc_1d
  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
  // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
  // CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32>
  gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
  gpu.func @prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
    xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
    gpu.return
  }

  //-----
  // CHECK-LABEL: test_load_nd
  // CHECK-LABEL: load_nd
  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
  gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
  gpu.func @load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
    gpu.return %ld : vector<24x32xf32>
@@ -90,12 +90,12 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_load_nd_1d
  // CHECK-LABEL: load_nd_1d
  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
  // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
  // CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
  // CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32>
  gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
  gpu.func @load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
    %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>> -> vector<64xf32>
    gpu.return %data : vector<64xf32>
@@ -103,11 +103,11 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_store_nd
  // CHECK-LABEL: store_nd
  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  // CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
  gpu.func @store_nd(%src: memref<24x32xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    %data = arith.constant dense<9.0> : vector<24x32xf32>
    xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -116,11 +116,11 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_store_nd_1d
  // CHECK-LABEL: store_nd_1d
  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
  // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
  // CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32>
  gpu.func @test_store_nd_1d(%src: memref<64xf32>) {
  gpu.func @store_nd_1d(%src: memref<64xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
    %data = arith.constant dense<9.0> : vector<64xf32>
    xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
@@ -129,7 +129,7 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_createNd_loadNd_storeNd
  // CHECK-LABEL: createNd_loadNd_storeNd
  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
  //CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
@@ -137,7 +137,7 @@ gpu.module @test {
  //CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
  //CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
  //CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
  gpu.func @createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    %data = arith.constant dense<9.0> : vector<24x32xf32>
    %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
@@ -148,23 +148,23 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_dpas
  // CHECK-LABEL: dpas
  // CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16>
  //CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16>
  //CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16>
  //CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32>
  //CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
  gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
  gpu.func @dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
    %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
    gpu.return %c : vector<32x32xf32>
  }

  //-----

  // CHECK-LABEL: test_create_tdesc_vec
  // CHECK-LABEL: create_tdesc_vec
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
  gpu.func @create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
@@ -177,10 +177,10 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_create_tdesc_step
  // CHECK-LABEL: create_tdesc_step
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
  gpu.func @create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
    %step = arith.constant dense<8> : vector<32xindex>
    %seq = vector.step : vector<32xindex>
    %cst = arith.muli %seq, %step : vector<32xindex>
@@ -190,11 +190,11 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_load
  // CHECK-LABEL: load
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
  gpu.func @test_load(%src: ui64) -> vector<32xf32> {
  gpu.func @load(%src: ui64) -> vector<32xf32> {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
@@ -212,11 +212,11 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_prefetch
  // CHECK-LABEL: prefetch
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  gpu.func @test_prefetch(%src: ui64) {
  gpu.func @prefetch(%src: ui64) {

    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
@@ -233,11 +233,11 @@ gpu.module @test {

  //-----

  // CHECK-LABEL: test_store
  // CHECK-LABEL: store
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
  gpu.func @test_store(%src: ui64) {
  gpu.func @store(%src: ui64) {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
@@ -256,47 +256,129 @@ gpu.module @test {
  }

  //-----

  // CHECK-LABEL: test_prefetch_load_store_update
  // CHECK-LABEL: create_tdesc_step_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
  // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
  // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>>
  gpu.func @create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
    %step = arith.constant dense<8> : vector<32xindex>
    %seq = vector.step : vector<32xindex>
    %cst = arith.muli %seq, %step : vector<32xindex>
    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
    gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
  }

  gpu.func @test_prefetch_load_store_update(%src: ui64) {
  //-----
  // CHECK-LABEL: create_tdesc_step_chunk2
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  gpu.func @create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
    %step = arith.constant dense<8> : vector<32xindex>
    %seq = vector.step : vector<32xindex>
    %cst = arith.muli %seq, %step : vector<32xindex>
    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
    gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
  }

  // CHECK-LABEL: create_tdesc_step_chunk3
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  gpu.func @create_tdesc_step_chunk3(%src: ui64) -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>> {
    %step = arith.constant dense<8> : vector<16xindex>
    %seq = vector.step : vector<16xindex>
    %cst = arith.muli %seq, %step : vector<16xindex>
    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
    gpu.return %tdesc : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
  }

  //-----
  // CHECK-LABEL: load_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  // CHECK-COUNT-4: xegpu.load {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>

  gpu.func @load_chunk(%src: ui64) -> vector<4x32xf32> {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>

    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

    %delta = arith.constant dense<[
      32, 32, 32, 32, 32, 32, 32, 32,
      32, 32, 32, 32, 32, 32, 32, 64,
      128, 128, 128, 128, 128, 128, 128, 128,
      128, 128, 128, 128, 128, 128, 128, 256
    ]> : vector<32xindex>
    %new_tdesc = xegpu.update_offset %tdesc, %delta
      : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>

    %c17 = arith.constant 17: index
    %mask = vector.create_mask %c17: vector<32xi1>

    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
    %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>

    gpu.return %ld : vector<4x32xf32>
  }

    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
    xegpu.store %st_vec, %tdesc, %mask:
      vector<32xf32>,
      !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
      vector<32xi1>

  //-----
  // CHECK-LABEL: store_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  // CHECK-COUNT-4: xegpu.store {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
  gpu.func @store_chunk(%src: ui64) {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>

    %c17 = arith.constant 17: index
    %mask = vector.create_mask %c17: vector<32xi1>

    %st_vec = arith.constant dense<1023.>: vector<4x32xf32>
    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: vector<4x32xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16,2]>>, vector<32xi1>

    gpu.return
  }

  //-----
  // CHECK-LABEL: prefetch_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  gpu.func @prefetch_chunk(%src: ui64) {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>
    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>

    gpu.return
  }

  //-----
  // CHECK-LABEL: update_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
  gpu.func @update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>
    %delta = arith.constant dense<32>: vector<32xindex>
    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>

    %new_tdesc = xegpu.update_offset %tdesc, %delta
      : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>

    gpu.return %new_tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
  }
}

@@ -19,6 +19,10 @@ using namespace mlir::xegpu;

namespace {

#define DEBUG_TYPE "test-xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

struct TestXeGPUUnrollingPatterns
    : public PassWrapper<TestXeGPUUnrollingPatterns,
                         OperationPass<gpu::GPUModuleOp>> {
@@ -48,7 +52,9 @@ struct TestXeGPUUnrollingPatterns
    options.setNativeShapeFn(
        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
          if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
                  xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
            xegpu::TensorDescType tdescTy;
            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
              tdescTy = createNdOp.getType();
@@ -61,20 +67,7 @@ struct TestXeGPUUnrollingPatterns
              tdescTy = loadNdOp.getTensorDescType();
            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
              tdescTy = storeNdOp.getTensorDescType();
            }

            if (auto layout = tdescTy.getLayoutAttr()) {
              auto inst_data = layout.getInstData();
              if (inst_data && layout.isSgLayout())
                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
                                            inst_data.asArrayRef().end());
            }
          }

          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
            xegpu::TensorDescType tdescTy;
            if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
            } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
              tdescTy = createOp.getType();
            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
              tdescTy = updateOp.getTensorDescType();
@@ -111,14 +104,40 @@ struct TestXeGPUUnrollingPatterns
            Attribute encoding = tdescTy.getEncoding();
            auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
                tdescTy.getLayout());

            // If the encoding is a ScatterTensorDescAttr, we need to
            // potentially adjust the chunk size based on the inst_data.
            if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
              auto scatterAttr =
                  mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
              int64_t chunkSize = scatterAttr.getChunkSize().getInt();

              if (chunkSize > 1) {
                int64_t blockedChunkSize = chunkSize;
                auto instData = layout.getInstData();
                if (!instData.empty())
                  blockedChunkSize = instData.asArrayRef().back();

                auto chunkSizeAttr = mlir::IntegerAttr::get(
                    mlir::IntegerType::get(ctx, 64), blockedChunkSize);

                // To create a new attribute with a different chunk_size:
                auto newEncoding = xegpu::ScatterTensorDescAttr::get(
                    ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);

                encoding = newEncoding;
              }
            }
            if (layout) {
              if (layout.getLaneLayout() == nullptr)
                layout = xegpu::LayoutAttr();
              else
                layout = layout.dropInstData();
            }

            newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                               layout);

          } else {
            newTy = type.clone(tileShape, elemTy);
          }
