import numpy as np
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import gpu
from mlir.dialects import memref
from mlir.dialects import nvgpu
from mlir.dialects import nvvm
from mlir.dialects import llvm
from mlir.dialects import builtin
from mlir.dialects import scf
from mlir.dialects import vector
from mlir.extras import types as T

TMA_LAST_DIM_F16 = 64  # 128B float16
WARP_SIZE = 32
WARP_GROUP_SIZE = WARP_SIZE * 4

PRODUCER_REGISTER_SIZE = 40
CONSUMER_REGISTER_SIZE = 232

PRODUCER_PRIMARY_THREAD = 128
CONSUMER_PRIMARY_THREAD = 0

# C++ uses this value to understand whether a dimension is dynamic or not.
MLIR_DYNAMIC = -9223372036854775808

DEBUG = False
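
# Notes on the magic numbers above: the TMA descriptors below use a 128-byte
# swizzle, so the innermost tile dimension spans 128 B / 2 B per f16 = 64
# elements. PRODUCER/CONSUMER_REGISTER_SIZE feed nvvm.setmaxregister: the
# TMA-issuing warpgroup shrinks to 40 registers per thread while the WGMMA
# warpgroup grows to 232. Thread 128 is the first thread of warpgroup 1 (the
# producer); thread 0 is the first thread of warpgroup 0 (the consumer).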


def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
    if not DEBUG and not forcePrint:
        return
    type_formats = []
    for arg in args:
        ty_format = None
        if ir.IndexType.isinstance(arg.type):
            ty_format = "%llu"
        if ir.IntegerType.isinstance(arg.type):
            width = ir.IntegerType(arg.type).width
            if width == 64:
                ty_format = "%llu"
            elif width == 32:
                ty_format = "%d"
            elif width == 1:
                ty_format = "%i"
        if ir.F32Type.isinstance(arg.type):
            ty_format = "%f"
        if ty_format is None:
            raise NotImplementedError(arg.type)
        type_formats.append(ty_format)
    if threadNumber != -1:
        tidx = gpu.thread_id(gpu.Dimension.x)
        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
        scf.yield_([])
    if_op = scf.IfOp(predicate)
    with ir.InsertionPoint(if_op.then_block):
        gpu.printf(fmt.format(*type_formats) + "\n", args)
        scf.yield_([])
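
# Illustrative use (all ops must be built inside a kernel region):
#   debug_print("[prod] iv={} stage={}", iv, stage, predicate=primaryThread)
# Each {} is replaced by a printf format chosen from the MLIR type of the
# corresponding argument, and the print is wrapped in scf.if(predicate).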


def get_type_size(ty):
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)


def get_mlir_ty(dtype):
    if dtype == np.float16:
        return T.f16()
    if dtype == np.float32:
        return T.f32()
    if dtype == np.float64:
        return T.f64()
    if dtype == np.int32:
        return T.i32()
    if dtype == np.int64:
        return T.i64()
    raise NotImplementedError(dtype)


def c(value, ty=None):
    ty = T.index() if ty is None else ty
    return arith.constant(ty, value)


def make_kernel_name(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=128,
    num_stages=3,
    use_warp_specialization=False,
):
    kernelName = "warpspecialized" if use_warp_specialization else "multistage"
    return f"{kernelName}_{M}x{N}x{K}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_{num_stages}"
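
# Example: make_kernel_name(use_warp_specialization=True) with the defaults
# yields "warpspecialized_4096x4096x4096_128x128x128_3".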


def generate_matmul_ws(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    # Limitations for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = ir.Type.parse("!gpu.async.token")
    a_elem_ty = get_mlir_ty(input_type)
    b_elem_ty = get_mlir_ty(input_type)
    c_elem_ty = get_mlir_ty(output_type)
    a_ty = ir.MemRefType.get((M, K), a_elem_ty)
    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    txcount = (a_tile_shape[0] * a_tile_shape[1] * get_type_size(a_elem_ty)) + (
        b_tile_shape[0] * b_tile_shape[1] * get_type_size(b_elem_ty)
    )
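
    # Bytes the mbarrier expects per stage: the A box (128x64) plus the full
    # B tile (64x128), all f16: 128*64*2 + 64*128*2 = 32768 bytes across the
    # three TMA copies issued per stage.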
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    mbar_ty = ir.Type.parse(
        f"!nvgpu.mbarrier.group<memorySpace = {smem_space}, "
        f"num_barriers = {num_stages}>"
    )
    a_tma_desc_ty = ir.Type.parse(
        f"!nvgpu.tensormap.descriptor<tensor = memref<"
        f"{BLOCK_M}x{TMA_LAST_DIM_F16}x{a_elem_ty}, {smem_space}>, "
        f"swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
    )
    b_tma_desc_ty = ir.Type.parse(
        f"!nvgpu.tensormap.descriptor<tensor = memref<"
        f"{BLOCK_K}x{TMA_LAST_DIM_F16}x{b_elem_ty}, {smem_space}>, "
        f"swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
    )
    acc_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.accumulator<fragmented=vector<"
        f"{BLOCK_M}x{BLOCK_N}x{c_elem_ty}>>"
    )
    a_wgmma_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.descriptor<tensor=memref<"
        f"{BLOCK_M}x{BLOCK_K}x{a_elem_ty}, {smem_space_str}>>"
    )
    b_wgmma_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.descriptor<tensor=memref<"
        f"{BLOCK_K}x{BLOCK_N}x{b_elem_ty}, {smem_space_str}>>"
    )
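
    # The A tile (128x64) fits a single swizzle_128b TMA box. The B tile is
    # 64x128, wider than the 64-element box limit, so it is transferred as two
    # 64x64 boxes (see the two b_tma_slice views in the kernel body).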
    kernelName = make_kernel_name(
        input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
    )
    with ir.InsertionPoint(module.body):
        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
        with ir.InsertionPoint(fop.add_entry_block()):
            a_host = fop.arguments[0]
            b_host = fop.arguments[1]
            c_host = fop.arguments[2]
            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
            smem_size = max(smem_size_input, smem_size_output)
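
            # With the 128x128x64 defaults: 16384 B per A tile + 16384 B per
            # B tile gives 32768 B per stage, 98304 B for 3 stages; the f32
            # output tile needs 128*128*4 = 65536 B, so the staged inputs
            # dominate. The buffers alias: the output reuses smem after the loop.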

            # Step 1. Allocate device memory and memcpy
            t1 = gpu.wait(token_ty, [])
            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
            t7 = gpu.wait(token_ty, [t6])

            # Step 2. Create TMA Descriptors
            tma_specs = [
                (a_device, a_tma_desc_ty, a_tma_shape),
                (b_device, b_tma_desc_ty, b_tma_shape),
            ]
            tma_descs = []
            for x_device, tensor_map_ty, tile_shape in tma_specs:
                x_unranked = memref.cast(
                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
                )
                tma_descs.append(
                    nvgpu.TmaCreateDescriptorOp(
                        tensor_map_ty, x_unranked, map(c, tile_shape)
                    ).result
                )
            a_tma_desc, b_tma_desc = tma_descs
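
            # TmaCreateDescriptorOp takes an unranked memref, hence the casts;
            # the descriptors are built host-side against the device
            # allocations and brought close to the SMs by the
            # tma_prefetch_descriptor calls inside the kernel.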

            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
            cta_m = M // BLOCK_M
            cta_n = N // BLOCK_N
            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
            grid = (cta_m, cta_n, 1)
            block = (WARP_GROUP_SIZE * 2, 1, 1)
            launch_op = gpu.LaunchOp(
                token_ty,
                [t7],
                *map(c, grid),
                *map(c, block),
                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
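
            # Two warpgroups = 256 threads per block: warpgroup 0 computes,
            # warpgroup 1 copies. The 12 index block arguments are gpu.launch's
            # implicit ids and sizes (block id, thread id, grid dim, block dim,
            # x/y/z each).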
            with ir.InsertionPoint(launch_op.body.blocks[0]):
                # GPU Step 0. This is needed for vectorized ld/st
                memref.assume_alignment(c_device, 16)
                dynamic_smem = gpu.dynamic_shared_memory(
                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                )
                ticks = c(10000000)

                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
                tidx = gpu.thread_id(gpu.Dimension.x)
                wgPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
                )
                warp_id = arith.divui(tidx, c(32))
                warpgroup_id = arith.divui(warp_id, c(4))
                is_producer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
                )
                is_consumer = arith.cmpi(
                    arith.CmpIPredicate.eq,
                    warpgroup_id,
                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
                )
                producerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
                )
                consumerPrimaryThread = arith.cmpi(
                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
                )
                bidx = gpu.block_id(gpu.Dimension.x)
                bidy = gpu.block_id(gpu.Dimension.y)
                dimX = arith.muli(bidx, c(BLOCK_M))
                dimY = arith.muli(bidy, c(BLOCK_N))
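
                # dimX/dimY are this CTA's output-tile origin in C; the
                # producer/consumer split keys off warpgroup_id, and each
                # warpgroup's first thread drives the predicated ops below.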

                # GPU Step 2. Initialize mbarrier groups
                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                for i in range(num_stages):
                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
                gpu.barrier()

                # GPU Step 3. Prefetch TMA descriptors
                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
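
                # Barrier protocol: mbarTMA[i] completes when stage i's tiles
                # have landed in shared memory; mbarDONE[i] completes when the
                # consumer has finished reading stage i, so the producer may
                # overwrite that buffer.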

                ns = num_stages if num_stages == 1 else num_stages - 1
                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
                    # Step 5.1. Reduce register size
                    nvvm.setmaxregister(
                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
                    )

                    # Step 5.2. TMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
                    )
                    with ir.InsertionPoint(for_op.body):
                        phaseParity = for_op.inner_iter_args[0]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))

                        # Step 5.2.1. Wait mbarDONE
                        debug_print(
                            "[prod] iv={} | mbarDONE[{}] try_wait phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarDONE, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[prod] iv={} | mbarDONE[{}] try_wait phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=producerPrimaryThread,
                        )
                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )
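
                        # mbarrier waits take a phase-parity bit rather than a
                        # count: a barrier's parity flips each time one of its
                        # phases completes. After touching the last slot, flip
                        # the local parity so the next wrap around slot 0
                        # waits on the new phase.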

                        # Step 5.2.2. Load TMA
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tma_slice = memref.view(
                            ir.MemRefType.get(
                                a_tma_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tma_slice_1 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        b_offset2 = arith.addi(
                            b_offset,
                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                        )
                        b_tma_slice_2 = memref.view(
                            ir.MemRefType.get(
                                b_tma_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset2,
                            [],
                        )
                        debug_print(
                            "[prod] a_offset={} b_offset={} b_offset2={}",
                            a_offset,
                            b_offset,
                            b_offset2,
                            predicate=producerPrimaryThread,
                        )
                        coord = arith.muli(c(64), iv)
                        nvgpu.TmaAsyncLoadOp(
                            a_tma_slice,
                            mbarTMA,
                            a_tma_desc,
                            coordinates=[coord, dimX],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_1,
                            mbarTMA,
                            b_tma_desc,
                            coordinates=[dimY, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
                        dimY2 = arith.addi(dimY, c(64))
                        nvgpu.TmaAsyncLoadOp(
                            b_tma_slice_2,
                            mbarTMA,
                            b_tma_desc,
                            coordinates=[dimY2, coord],
                            mbarId=stage,
                            predicate=producerPrimaryThread,
                        )
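
                        # coord = iv * 64 walks the K dimension. Coordinates
                        # are innermost-first: A's box sits at (k, m) =
                        # (coord, dimX); B's 128-wide tile arrives as two
                        # 64-wide boxes at (dimY, coord) and (dimY + 64, coord).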

                        # Step 5.2.3. Arrive mbarTMA
                        debug_print(
                            "[prod] iv={} | mbarTMA[{}] arrive",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        nvgpu.mbarrier_arrive_expect_tx(
                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
                        )
                        debug_print(
                            "[prod] iv={} | mbarTMA[{}] arrive [done]",
                            iv,
                            stage,
                            predicate=producerPrimaryThread,
                        )
                        scf.yield_([phaseParity])
                    scf.yield_([])
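
                # arrive.expect_tx arms mbarTMA[stage] with txcount bytes, so
                # its phase completes only after all three TMA copies above
                # have fully landed in shared memory.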

                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                if_op = scf.IfOp(is_consumer)
                with ir.InsertionPoint(if_op.then_block):
                    # Step 6.1. Increase register size
                    nvvm.setmaxregister(
                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
                    )

                    # GPU Step 6.2. Initialize MMA registers
                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)

                    # Step 6.3. MMA Main Loop
                    for_op = scf.ForOp(
                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                    )
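
                    # Both the accumulator and the phase parity are loop-
                    # carried: WGMMA accumulates over all K // BLOCK_K tiles
                    # and the result is stored once, after the loop.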
                    with ir.InsertionPoint(for_op.body):
                        # Step 6.3.1. Wait mbarTMA
                        phaseParity = for_op.inner_iter_args[1]
                        iv = for_op.induction_variable
                        stage = arith.remui(iv, c(num_stages))
                        debug_print(
                            "[cons] iv={} | mbarTMA[{}] try_wait phase={}",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )
                        nvgpu.MBarrierTryWaitParityOp(
                            mbarTMA, phaseParity, ticks, mbarId=stage
                        )
                        debug_print(
                            "[cons] iv={} | mbarTMA[{}] try_wait phase={} [done]",
                            iv,
                            stage,
                            phaseParity,
                            predicate=consumerPrimaryThread,
                        )

                        # Step 6.3.2. Create WGMMA Descriptors
                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
                        a_tile_slice = memref.view(
                            ir.MemRefType.get(
                                a_tile_shape, a_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            a_offset,
                            [],
                        )
                        b_offset = arith.addi(
                            arith.muli(stage, c(rhs_tile_bytes)),
                            c(lhs_tile_bytes * num_stages),
                        )
                        b_tile_slice = memref.view(
                            ir.MemRefType.get(
                                b_tile_shape, b_elem_ty, memory_space=smem_space
                            ),
                            dynamic_smem,
                            b_offset,
                            [],
                        )
                        debug_print(
                            "[cons] a_offset={} b_offset={}",
                            a_offset,
                            b_offset,
                            predicate=consumerPrimaryThread,
                        )
                        da = nvgpu.WarpgroupGenerateDescriptorOp(
                            a_wgmma_ty, a_tile_slice, a_tma_desc
                        )
                        db = nvgpu.WarpgroupGenerateDescriptorOp(
                            b_wgmma_ty, b_tile_slice, b_tma_desc
                        )
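
                        # The consumer views the same smem bytes as full tiles
                        # (A: 128x64, B: 64x128); WarpgroupGenerateDescriptorOp
                        # folds the smem address and the swizzle mode recorded
                        # in the TMA descriptor into the 64-bit wgmma matrix
                        # descriptors.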

                        # Step 6.3.3. MMA
                        carry_acc = for_op.inner_iter_args[0]
                        new_acc = nvgpu.WarpgroupMmaOp(
                            acc.type, da, db, carry_acc, transposeB=True
                        )
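
                        # Step 6.3.4 below releases the previous stage: once
                        # stage s is feeding the WGMMA pipeline, buffer s-1
                        # (wrapping to the last slot for s == 0) may be
                        # overwritten, so mbarDONE[s-1] is armed. iv == 0 has
                        # no predecessor and is skipped.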

                        # Step 6.3.4. Arrive mbarDONE
                        if num_stages == 1:
                            p_arrive = consumerPrimaryThread
                        else:
                            p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                            p_arrive = arith.andi(consumerPrimaryThread, p1)
                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
                            barId = arith.select(
                                p, c(num_stages - 1), arith.subi(stage, c(1))
                            )
                            debug_print(
                                "[cons] iv={} | mbarDONE[{}] arrive",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            nvgpu.mbarrier_arrive(
                                ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
                            )
                            debug_print(
                                "[cons] iv={} | mbarDONE[{}] arrive [done]",
                                iv,
                                barId,
                                predicate=consumerPrimaryThread,
                            )
                            scf.yield_([])

                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                        phaseParity = arith.select(
                            p,
                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                            phaseParity,
                        )

                        # Step 6.3.5. Yield
                        scf.yield_([new_acc, phaseParity])

                    # Step 6.4. Wait for all WGMMA groups
                    nvvm.WgmmaWaitGroupSyncOp(0)
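
                    # The loop only releases stages it has moved past, so arm
                    # mbarDONE one final time for the slot that held the last
                    # K tile, letting the producer's final wait complete.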

                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
                        barId = c((K // BLOCK_K) % num_stages)
                        nvgpu.mbarrier_arrive(
                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
                        )
                        scf.yield_([])

                    # Step 6.5. Epilogue (registers --> shared memory)
                    acc_smem_ty = ir.MemRefType.get(
                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
                    )
                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                    debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                    scf.yield_([])
                gpu.barrier()
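
                # All 256 threads synchronize here: the consumer's smem output
                # tile must be fully written before both warpgroups cooperate
                # on the global-memory store below.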

                # GPU Step 7. Epilogue (shared memory --> global memory)
                fd = ir.MemRefType.get(
                    (BLOCK_M * BLOCK_N,), c_elem_ty, memory_space=smem_space
                )
                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                rty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N),
                    c_elem_ty,
                    ir.Attribute.parse(f"strided<[{N}, 1], offset: ?>"),
                )
                c_device_per_block = memref.SubViewOp(
                    rty,
                    c_device,
                    [dimX, dimY],
                    [],
                    [],
                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
                    [BLOCK_M, BLOCK_N],
                    [1, 1],
                )
                vlen = 1
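
                # vlen = 1 keeps the copy scalar-width; the 16-byte alignment
                # assumed on c_device in GPU Step 0 is what would permit a
                # wider vector here.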
                for_op = scf.ForOp(
                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
                )
                with ir.InsertionPoint(for_op.body):
                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
                    vdata = vector.load(
                        ir.VectorType.get((vlen,), c_elem_ty),
                        collapsed_smem,
                        [for_op.induction_variable],
                    )
                    vector.store(vdata, c_device_per_block, [x, y])
                    scf.yield_([])

                gpu.terminator()

            # Step 4. Copy back to host
            t8 = gpu.wait(token_ty, [launch_op])
            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
            gpu.dealloc(token_ty, [t8], a_device)
            gpu.dealloc(token_ty, [t8], b_device)
            gpu.wait(token_ty, [t9])
            gpu.dealloc(token_ty, [t8], c_device)
            func.ReturnOp([])

    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    module.operation.verify()
    return module


def generate_matmul_multistage(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    num_stages=3,
):
    # Limitations for now
    assert input_type == np.float16
    assert output_type == np.float32
    assert BLOCK_M == 128
    assert BLOCK_N == 128
    assert BLOCK_K == 64
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    module = ir.Module.create()
    token_ty = ir.Type.parse("!gpu.async.token")
    a_elem_ty = get_mlir_ty(input_type)
    b_elem_ty = get_mlir_ty(input_type)
    c_elem_ty = get_mlir_ty(output_type)
    a_ty = ir.MemRefType.get((M, K), a_elem_ty)
    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    txcount = (a_tile_shape[0] * a_tile_shape[1] * get_type_size(a_elem_ty)) + (
        b_tile_shape[0] * b_tile_shape[1] * get_type_size(b_elem_ty)
    )
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    mbar_ty = ir.Type.parse(
        f"!nvgpu.mbarrier.group<memorySpace = {smem_space}, "
        f"num_barriers = {num_stages}>"
    )
    a_tma_desc_ty = ir.Type.parse(
        f"!nvgpu.tensormap.descriptor<tensor = memref<"
        f"{BLOCK_M}x{TMA_LAST_DIM_F16}x{a_elem_ty}, {smem_space}>, "
        f"swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
    )
    b_tma_desc_ty = ir.Type.parse(
        f"!nvgpu.tensormap.descriptor<tensor = memref<"
        f"{BLOCK_K}x{TMA_LAST_DIM_F16}x{b_elem_ty}, {smem_space}>, "
        f"swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
    )
    acc_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.accumulator<fragmented=vector<"
        f"{BLOCK_M}x{BLOCK_N}x{c_elem_ty}>>"
    )
    a_wgmma_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.descriptor<tensor=memref<"
        f"{BLOCK_M}x{BLOCK_K}x{a_elem_ty}, {smem_space_str}>>"
    )
    b_wgmma_ty = ir.Type.parse(
        f"!nvgpu.warpgroup.descriptor<tensor=memref<"
        f"{BLOCK_K}x{BLOCK_N}x{b_elem_ty}, {smem_space_str}>>"
    )
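
    # Type setup mirrors generate_matmul_ws; see the comments there for the
    # txcount arithmetic and why B moves as two 64-wide TMA boxes.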

    with ir.InsertionPoint(module.body):
        kernelName = make_kernel_name(
            input_type,
            output_type,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_stages,
            False,
        )
        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
        with ir.InsertionPoint(fop.add_entry_block()):
            a_host = fop.arguments[0]
            b_host = fop.arguments[1]
            c_host = fop.arguments[2]
            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
            smem_size = max(smem_size_input, smem_size_output)

            # Step 1. Allocate device memory and memcpy
            t1 = gpu.wait(token_ty, [])
            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
            t7 = gpu.wait(token_ty, [t6])

            # Step 2. Create TMA Descriptors
            tma_specs = [
                (a_device, a_tma_desc_ty, a_tma_shape),
                (b_device, b_tma_desc_ty, b_tma_shape),
            ]
            tma_descs = []
            for x_device, tensor_map_ty, tile_shape in tma_specs:
                x_unranked = memref.cast(
                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
                )
                tma_descs.append(
                    nvgpu.TmaCreateDescriptorOp(
                        tensor_map_ty, x_unranked, map(c, tile_shape)
                    ).result
                )
            a_tma_desc, b_tma_desc = tma_descs

            # Step 3. Launch Kernel with 1 Warpgroup
            cta_m = M // BLOCK_M
            cta_n = N // BLOCK_N
            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
            grid = (cta_m, cta_n, 1)
            block = (WARP_GROUP_SIZE, 1, 1)
            launch_op = gpu.LaunchOp(
                token_ty,
                [t7],
                *map(c, grid),
                *map(c, block),
                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
            )
            launch_op.body.blocks.append(*([T.index()] * 12))
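
            # Unlike the warp-specialized kernel, one 128-thread warpgroup
            # both issues the TMA copies and runs the WGMMA pipeline, so no
            # mbarDONE barrier or register re-budgeting is needed.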
            with ir.InsertionPoint(launch_op.body.blocks[0]):
                # GPU Step 0. Bootstrapping
                memref.assume_alignment(c_device, 16)
                dynamic_smem = gpu.dynamic_shared_memory(
                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                )
                ticks = c(10000000)
                tidx = gpu.thread_id(gpu.Dimension.x)
                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
                warpId = arith.divui(tidx, c(32))
                bidx = gpu.block_id(gpu.Dimension.x)
                bidy = gpu.block_id(gpu.Dimension.y)
                dimX = arith.muli(bidx, c(BLOCK_M))
                dimY = arith.muli(bidy, c(BLOCK_N))

                # GPU Step 1. Initialize mbarrier groups
                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                for i in range(num_stages):
                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
                gpu.barrier()

                # GPU Step 2. Prefetch TMA descriptors
                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)

                # GPU Step 3. Prologue (global memory --> shared memory)
                ns = num_stages if num_stages == 1 else num_stages - 1
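
                # Software pipelining depth: prefetch ns = num_stages - 1
                # tiles up front (num_stages when it is 1); the main loop then
                # keeps refilling slot (iv + ns) % num_stages while computing
                # tile iv.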
                for_op = scf.ForOp(c(0), c(ns), c(1))
                with ir.InsertionPoint(for_op.body):
                    iv = for_op.induction_variable

                    # Step 3.1. Calculate offsets
                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
                    a_tma_slice = memref.view(
                        ir.MemRefType.get(
                            a_tma_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(iv, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tma_slice_1 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    b_offset2 = arith.addi(
                        b_offset,
                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                    )
                    b_tma_slice_2 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset2,
                        [],
                    )

                    # Step 3.2. TMA Load
                    coord = arith.muli(c(64), iv)
                    dimY2 = arith.addi(dimY, c(64))
                    debug_print(
                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
                        a_offset,
                        b_offset,
                        b_offset2,
                        coord,
                        dimX,
                        dimY,
                        coord,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        a_tma_slice,
                        mbarTMA,
                        a_tma_desc,
                        coordinates=[coord, dimX],
                        mbarId=iv,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_1,
                        mbarTMA,
                        b_tma_desc,
                        coordinates=[dimY, coord],
                        mbarId=iv,
                        predicate=primaryThread,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_2,
                        mbarTMA,
                        b_tma_desc,
                        coordinates=[dimY2, coord],
                        mbarId=iv,
                        predicate=primaryThread,
                    )

                    # Step 3.3. mbarTMA arrive
                    debug_print(
                        "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
                    )
                    nvgpu.mbarrier_arrive_expect_tx(
                        mbarTMA, c(txcount), iv, predicate=primaryThread
                    )
                    debug_print(
                        "[Prologue] mbarTMA[{}] arrive [done]",
                        iv,
                        predicate=primaryThread,
                    )
                    scf.yield_([])

                # GPU Step 4. Main Loop
                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
                for_op = scf.ForOp(
                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                )
                with ir.InsertionPoint(for_op.body):
                    # Step 4.1. Wait mbarTMA
                    phaseParity = for_op.inner_iter_args[1]
                    iv = for_op.induction_variable
                    stage = arith.remui(iv, c(num_stages))
                    debug_print(
                        "[MainLoop] mbarTMA[{}] try_wait phase={}",
                        stage,
                        phaseParity,
                        predicate=primaryThread,
                    )
                    nvgpu.MBarrierTryWaitParityOp(
                        mbarTMA, phaseParity, ticks, mbarId=stage
                    )
                    debug_print(
                        "[MainLoop] mbarTMA[{}] try_wait phase={} [done]",
                        stage,
                        phaseParity,
                        predicate=primaryThread,
                    )

                    # Step 4.2. Create WGMMA Descriptors
                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
                    a_tile_slice = memref.view(
                        ir.MemRefType.get(
                            a_tile_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(stage, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tile_slice = memref.view(
                        ir.MemRefType.get(
                            b_tile_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    debug_print(
                        "[MainLoop] iv={} MMA a_offset={} b_offset={}",
                        iv,
                        a_offset,
                        b_offset,
                        predicate=primaryThread,
                    )
                    da = nvgpu.WarpgroupGenerateDescriptorOp(
                        a_wgmma_ty, a_tile_slice, a_tma_desc
                    )
                    db = nvgpu.WarpgroupGenerateDescriptorOp(
                        b_wgmma_ty, b_tile_slice, b_tma_desc
                    )

                    # Step 4.3. MMA
                    carry_acc = for_op.inner_iter_args[0]
                    new_acc = nvgpu.WarpgroupMmaOp(
                        acc.type, da, db, carry_acc, transposeB=True
                    )
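
                    # With a single stage there is no double buffering, so the
                    # WGMMA group must drain before the TMA below overwrites
                    # the very tile it is reading.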
                    if num_stages == 1:
                        nvvm.WgmmaWaitGroupSyncOp(0)

                    # Step 4.4. Load TMA for next stage
                    p1 = arith.cmpi(
                        arith.CmpIPredicate.ult,
                        arith.addi(iv, c(ns)),
                        c(K // BLOCK_K),
                    )
                    p = arith.andi(primaryThread, p1)
                    nextStage = arith.addi(iv, c(ns))
                    nextSlot = arith.remui(nextStage, c(num_stages))
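
                    # Pipeline top-up: while computing tile iv, the primary
                    # thread starts the TMA for tile iv + ns into slot
                    # (iv + ns) % num_stages; p masks the final iterations so
                    # no out-of-range tiles are requested.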
                    a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))

                    debug_print(
                        "[MainLoop] mbarTMA[{}] arrive",
                        nextSlot,
                        predicate=p,
                    )
                    nvgpu.mbarrier_arrive_expect_tx(
                        mbarTMA, c(txcount), nextSlot, predicate=p
                    )
                    debug_print(
                        "[MainLoop] mbarTMA[{}] arrive [done]",
                        nextSlot,
                        predicate=p,
                    )

                    a_tma_slice = memref.view(
                        ir.MemRefType.get(
                            a_tma_shape, a_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        a_offset,
                        [],
                    )
                    b_offset = arith.addi(
                        arith.muli(nextSlot, c(rhs_tile_bytes)),
                        c(lhs_tile_bytes * num_stages),
                    )
                    b_tma_slice_1 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset,
                        [],
                    )
                    b_offset2 = arith.addi(
                        b_offset,
                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                    )
                    b_tma_slice_2 = memref.view(
                        ir.MemRefType.get(
                            b_tma_shape, b_elem_ty, memory_space=smem_space
                        ),
                        dynamic_smem,
                        b_offset2,
                        [],
                    )

                    coord = arith.muli(c(64), nextStage)
                    debug_print(
                        "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
                        iv,
                        a_offset,
                        b_offset,
                        b_offset2,
                        coord,
                        dimX,
                        dimY,
                        coord,
                        predicate=p,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        a_tma_slice,
                        mbarTMA,
                        a_tma_desc,
                        coordinates=[coord, dimX],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_1,
                        mbarTMA,
                        b_tma_desc,
                        coordinates=[dimY, coord],
                        mbarId=nextSlot,
                        predicate=p,
                    )
                    dimY2 = arith.addi(dimY, c(64))
                    nvgpu.TmaAsyncLoadOp(
                        b_tma_slice_2,
                        mbarTMA,
                        b_tma_desc,
                        coordinates=[dimY2, coord],
                        mbarId=nextSlot,
                        predicate=p,
                    )

                    # Step 4.5. Change the phaseParity
                    p2 = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                    phaseParity = arith.select(
                        p2,
                        arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                        phaseParity,
                    )

                    # Step 4.6. Yield
                    scf.yield_([new_acc, phaseParity])

                # Step 5. Wait for all WGMMA groups
                nvvm.WgmmaWaitGroupSyncOp(0)

                # Step 6. Epilogue (registers --> shared memory)
                acc_smem_ty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
                )
                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                debug_print("Storing", predicate=primaryThread)
                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                gpu.barrier()

                # GPU Step 7. Epilogue (shared memory --> global memory)
                fd = ir.MemRefType.get(
                    (BLOCK_M * BLOCK_N,), c_elem_ty, memory_space=smem_space
                )
                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                rty = ir.MemRefType.get(
                    (BLOCK_M, BLOCK_N),
                    c_elem_ty,
                    ir.Attribute.parse(f"strided<[{N}, 1], offset: ?>"),
                )
                c_device_per_block = memref.SubViewOp(
                    rty,
                    c_device,
                    [dimX, dimY],
                    [],
                    [],
                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
                    [BLOCK_M, BLOCK_N],
                    [1, 1],
                )
                vlen = 1
                for_op = scf.ForOp(
                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
                )
                with ir.InsertionPoint(for_op.body):
                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
                    vdata = vector.load(
                        ir.VectorType.get((vlen,), c_elem_ty),
                        collapsed_smem,
                        [for_op.induction_variable],
                    )
                    vector.store(vdata, c_device_per_block, [x, y])
                    scf.yield_([])

                gpu.terminator()

            # Step 4. Copy back to host
            t8 = gpu.wait(token_ty, [launch_op])
            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
            gpu.dealloc(token_ty, [t8], a_device)
            gpu.dealloc(token_ty, [t8], b_device)
            gpu.wait(token_ty, [t9])
            gpu.dealloc(token_ty, [t8], c_device)
            func.ReturnOp([])

    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    module.operation.verify()
    return module
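

# Minimal usage sketch. Assumptions: the MLIR Python bindings (with the nvgpu
# and nvvm dialects) are on PYTHONPATH, and a separate sm90 lowering/JIT
# pipeline (as in the in-tree test driver, omitted here) consumes the modules.
if __name__ == "__main__":
    with ir.Context(), ir.Location.unknown():
        print(generate_matmul_ws(M=4096, N=4096, K=4096, num_stages=3))
        print(generate_matmul_multistage(M=4096, N=4096, K=4096, num_stages=3))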