from enum import Enum
import functools, sys, ctypes, os, errno
import numpy as np
from functools import partialmethod
from mlir import ir
from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
from mlir.extras import types as T
from mlir import runtime as rt
from tools import nvgpucompiler

# Marker for a dynamic dimension (mlir::ShapedType::kDynamic, i.e. INT64_MIN).
MLIR_DYNAMIC = -9223372036854775808


def const(value: int, ty=None):
    """Materialize a Python value as an arith.constant of type `ty` (index by default); existing ir.Values pass through."""
    ty = T.index() if ty is None else ty
    if isinstance(value, ir.Value) and (
        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
    ):
        return value
    return arith.constant(ty, value)


def get_type_size(ty):
    """Return the size in bytes of an MLIR memref, float, or integer type."""
    if ir.MemRefType.isinstance(ty):
        size = get_type_size(ty.element_type)
        for sz in ty.shape:
            size *= sz
        return size
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)


def get_mlir_func_obj_ty(inputArgs):
    """Pack Python scalars and NumPy arrays into ctypes objects for ExecutionEngine.invoke."""
    args = []
    c_int_p = ctypes.c_int * 1
    c_float_p = ctypes.c_float * 1
    c_bool_p = ctypes.c_bool * 1
    for arg in inputArgs:
        if isinstance(arg, bool):
            args.append(c_bool_p(arg))
        elif isinstance(arg, int):
            args.append(c_int_p(arg))
        elif isinstance(arg, float):
            args.append(c_float_p(arg))
        elif isinstance(arg, np.ndarray):
            args.append(
                ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
            )
        else:
            raise NotImplementedError(arg)
    return args


class Mbarriers:
    """A group of mbarrier objects in workgroup (shared) memory; `mbarriers[i]` selects barrier `i`."""

    def __init__(self, number_of_barriers=1):
        self.mbar_ty = ir.Type.parse(
            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
            + str(number_of_barriers)
            + ">"
        )
        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
        self.number_of_barriers = number_of_barriers

    def __getitem__(self, key):
        self.id_op = const(key)
        return self

    def init(self, count: int, predicate=None):
        count_op = const(count)
        if predicate is None:
            nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
        else:
            nvgpu.mbarrier_init(
                self.mbar_group_op, count_op, self.id_op, predicate=predicate
            )

    def arrive(self, txcount: int = 0, predicate=None):
        if txcount != 0:
            txcount_op = const(txcount)
            nvgpu.mbarrier_arrive_expect_tx(
                self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
            )
        else:
            nvgpu.mbarrier_arrive(
                ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
            )

    def try_wait(self, phase: bool = False, ticks: int = 10000000):
        ticks_op = const(ticks)
        phase_op = const(phase, T.bool())
        nvgpu.MBarrierTryWaitParityOp(
            self.mbar_group_op,
            phase_op,
            ticks_op,
            mbarId=self.id_op,
        )


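# Illustrative usage sketch (not exercised in this file): inside a kernel body,
# a barrier group is created, one thread initializes it, and consumers wait on
# it after asynchronous copies arrive. `is_leader` and `tx_bytes` below are
# hypothetical values from the surrounding kernel.
#
#   mbar_group = Mbarriers(number_of_barriers=1)
#   mbar_group[0].init(1, predicate=is_leader)
#   mbar_group[0].arrive(txcount=tx_bytes)
#   mbar_group[0].try_wait()

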
class TMA:
    """A class that builds a TMA descriptor."""

    def __init__(
        self,
        tma_box_shape,
        memref_ty,
        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
    ):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape
        self.memref_ty = memref_ty  # MemRefType
        self.tma_memref = ir.MemRefType.get(tma_box_shape, memref_ty.element_type)

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type."""
        tensorMemrefType = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            tensorMemrefType,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )

    def create_descriptor(self, device_ptr):
        tma_descriptor_ty = self.tensormap_descriptor_ty
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
        )
        return self.tma_descriptor.result

    def prefetch(self, predicate=None):
        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)

    def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
        nvgpu.TmaAsyncLoadOp(
            dest,
            mbarrier.mbar_group_op,
            self.tma_descriptor,
            coordinates=map(const, coords),
            mbarId=mbarrier.id_op,
            predicate=predicate,
        )


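# Illustrative usage sketch (not exercised in this file): a TMA object pairs a
# box shape with the source memref type; its descriptor is created on the host
# side from a device pointer and later drives asynchronous loads into shared
# memory. `a_memref_ty`, `a_device`, `a_smem`, and `mbar_group` are hypothetical.
#
#   a_tma = TMA([128, 64], a_memref_ty)
#   a_tma.create_descriptor(a_device)
#   ...
#   a_tma.load(a_smem, mbar_group[0])   # inside the kernel body

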
WARP_GROUP_SIZE = 128  # Number of threads in a warpgroup


class Warpgroup:
    """Context manager that wraps its body in an scf.if so the emitted ops run only on the warpgroup containing `primary_thread`, and adjusts that warpgroup's register budget."""

    def __init__(self, primary_thread, register_size):
        assert (primary_thread % WARP_GROUP_SIZE) == 0
        tidx = gpu.thread_id(gpu.Dimension.x)
        self.primary_thread = primary_thread
        self.register_size = register_size
        self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
        self.wg_id = tidx / WARP_GROUP_SIZE
        self.is_me = self.wg_id == (primary_thread // WARP_GROUP_SIZE)

    def __enter__(self):
        if_op = scf.IfOp(self.is_me)
        self.ipoint_op = ir.InsertionPoint(if_op.then_block)
        self.ipoint_op.__enter__()
        if self.register_size < 64:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
        else:
            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)

    def __exit__(self, exc_type, exc_value, traceback):
        scf.yield_([])
        self.ipoint_op.__exit__(exc_type, exc_value, traceback)
        return True


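# Illustrative usage sketch (inside a kernel body): specialize code per
# warpgroup. The primary thread index and register size below are arbitrary
# example values, not requirements of this helper.
#
#   with Warpgroup(primary_thread=128, register_size=240):
#       ...  # ops emitted here run only on the second warpgroup

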
class WGMMAType(Enum):
    Accumulator = 1
    Descriptor = 2


class WGMMAMatrix:
    """A warpgroup MMA operand: either an accumulator fragment or a descriptor over shared memory."""

    def __init__(
        self,
        matrix_type: WGMMAType,
        shape: list = None,
        desc: TMA = None,
        smem=None,
        ty=None,
        acc_op=None,
    ):
        if acc_op is None:
            self.M = shape[0]
            self.N = shape[1]
            self.ty = ty
            self.matrix_type = matrix_type
            self.desc = desc
            self.smem = smem
            if matrix_type is WGMMAType.Accumulator:
                self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
        elif acc_op:
            self.acc_op = acc_op
            self.matrix_type = WGMMAType.Accumulator

    @property
    def acc_ty(self):
        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
        return ir.Type.parse(parse_str)

    @property
    def wgmma_ty(self):
        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
        return ir.Type.parse(parse_str)

    def store_accumulator(self, dest):
        assert self.matrix_type == WGMMAType.Accumulator
        nvgpu.warpgroup_mma_store(self.acc_op, dest)

    def update_smem(self, smem):
        self.smem = smem

    def update_accumulator(self, acc_op):
        self.acc_op = acc_op

    def __matmul__(self, rhs):
        lhs = nvgpu.warpgroup_generate_descriptor(
            self.wgmma_ty, self.smem, self.desc.tma_descriptor
        )
        rhs = nvgpu.warpgroup_generate_descriptor(
            rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
        )
        return [lhs, rhs]

    def __iadd__(self, matmulResult):
        lhs = matmulResult[0]
        rhs = matmulResult[1]
        acc_op = nvgpu.WarpgroupMmaOp(
            self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
        )
        return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)


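# Illustrative usage sketch (inside a kernel body): `@` generates the wgmma
# descriptors and `+=` emits the warpgroup MMA, so a tile update reads like
# linear algebra. The shapes, TMA objects, and shared-memory views below are
# hypothetical.
#
#   A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
#   B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
#   D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
#   D += A @ B
#   D.store_accumulator(d_smem)

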
def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
    """Return the dynamic shared memory base, or a memref.view of `shape` x `ty` starting `offset` bytes into it."""
    smem_space_str = "#gpu.address_space<workgroup>"
    smem_space = ir.Attribute.parse(smem_space_str)
    dynamic_smem = gpu.dynamic_shared_memory(
        ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
    )
    if shape is None:
        return dynamic_smem
    memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
    return memref.view(
        ir.MemRefType.get(
            memref_ty.shape, memref_ty.element_type, memory_space=smem_space
        ),
        dynamic_smem,
        const(offset),
        [],
    )


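# Illustrative usage sketch (inside a kernel body): carve two tiles out of the
# dynamic shared memory region, the second starting right after the first.
# The tile shapes are hypothetical.
#
#   a_smem = get_dynamic_shared_memory([128, 64], T.f16())
#   b_smem = get_dynamic_shared_memory(
#       [64, 128], T.f16(), offset=128 * 64 * get_type_size(T.f16())
#   )

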
def get_mlir_ty(arg):
    """Map a Python scalar or NumPy array argument to its MLIR type."""

    def get_mlir_ty_from_np(dtype):
        if dtype == np.float16:
            return T.f16()
        if dtype == np.float32:
            return T.f32()
        if dtype == np.float64:
            return T.f64()
        if dtype == np.int32:
            return T.i32()
        if dtype == np.int64:
            return T.i64()
        raise NotImplementedError(dtype)

    if isinstance(arg, bool):
        return T.bool()
    elif isinstance(arg, int):
        return T.index()
    elif isinstance(arg, float):
        return T.f32()
    elif isinstance(arg, np.ndarray):
        descriptor = rt.get_ranked_memref_descriptor(arg)
        dtype = get_mlir_ty_from_np(arg.dtype)
        shape = descriptor.shape
        return ir.MemRefType.get(shape, dtype)
    raise NotImplementedError(arg)


class NVDSL:
    @staticmethod
    def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                launch_op = gpu.LaunchOp(
                    None,
                    [],
                    *map(const, grid),
                    *map(const, block),
                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
                )
                launch_op.body.blocks.append(*([T.index()] * 12))
                with ir.InsertionPoint(launch_op.body.blocks[0]):
                    result = func(*args, **kwargs)
                    gpu.terminator()
                    return result

            return wrapper

        return decorator

    @staticmethod
    def mlir_func(funcBody):
        @functools.wraps(funcBody)
        def wrapper(*args, **kwargs):
            function_name = funcBody.__name__

            def saveIR(module):
                """Save generated IR"""
                if True:  # self.saveIR:
                    # print(mlir_nvgpu_module)
                    original_stdout = sys.stdout
                    with open("nvdsl.mlir", "w") as f:
                        sys.stdout = f
                        print(module)
                        sys.stdout = original_stdout

            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
                """Generate MLIR's Arith dialects binary operations."""
                rhs = const(rhs)
                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
                    op += "F"
                    if op.startswith("Cmp"):
                        predicateAttr = getattr(arith, "CmpFPredicate").__dict__[
                            predAtt
                        ]
                elif arith._is_integer_like_type(
                    lhs.type
                ) and arith._is_integer_like_type(rhs.type):
                    if op == "Div" or op == "Rem":
                        op += "U"
                    op += "I"
                    if op.startswith("Cmp"):
                        predicateAttr = getattr(arith, "CmpIPredicate").__dict__[
                            predAtt
                        ]
                else:
                    raise NotImplementedError(
                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
                    )

                if op.startswith("Cmp"):
                    op = getattr(arith, f"{op}Op")
                    return op(predicateAttr, lhs, rhs).result
                else:
                    op = getattr(arith, f"{op}Op")
                    return op(lhs, rhs).result

            @ir.register_value_caster(ir.IndexType.static_typeid)
            @ir.register_value_caster(ir.F32Type.static_typeid)
            @ir.register_value_caster(ir.F16Type.static_typeid)
            @ir.register_value_caster(ir.F64Type.static_typeid)
            @ir.register_value_caster(ir.IntegerType.static_typeid)
            class ArithValue(ir.Value):
                """Overloads operators for MLIR's Arith dialects binary operations."""

                def __init__(self, v):
                    super().__init__(v)

                __add__ = partialmethod(_binary_op, op="Add")
                __sub__ = partialmethod(_binary_op, op="Sub")
                __mul__ = partialmethod(_binary_op, op="Mul")
                __truediv__ = partialmethod(_binary_op, op="Div")
                __mod__ = partialmethod(_binary_op, op="Rem")
                __xor__ = partialmethod(_binary_op, op="XOr")
                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
                __and__ = partialmethod(_binary_op, op="And")
                __or__ = partialmethod(_binary_op, op="Or")

                def __str__(self):
                    return (
                        super()
                        .__str__()
                        .replace(ir.Value.__name__, ArithValue.__name__)
                    )

            # Generate MLIR Context and start generating IR
            with ir.Context(), ir.Location.unknown():
                types = []
                for arg in args:
                    types.append(get_mlir_ty(arg))

                # Build IR
                module = ir.Module.create()
                with ir.InsertionPoint(module.body):
                    fop = func.FuncOp(function_name, (types, []))
                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
                    with ir.InsertionPoint(fop.add_entry_block()):
                        fargs = []
                        for i, a in enumerate(types):
                            fargs.append(fop.arguments[i])

                        # Call user function body
                        result = funcBody(*fargs, **kwargs)
                        func.ReturnOp([])

                # Save IR in a file
                # saveIR(module)

                # Verify the module
                module.operation.verify()

                # Compile and JIT MLIR module
                options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
                support_lib = os.getenv("SUPPORT_LIB")
                if not os.path.exists(support_lib):
                    raise FileNotFoundError(
                        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
                    )
                compiler = nvgpucompiler.NvgpuCompiler(
                    options, opt_level=3, shared_libs=[support_lib]
                )
                engine = compiler.compile_and_jit(module)

                # Convert input arguments to MLIR arguments
                newArgs = get_mlir_func_obj_ty(args)

                # Run the compiled program
                engine.invoke(function_name, *newArgs)

            return result

        return wrapper


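# Illustrative sketch of how the two decorators compose (modeled on the matmul
# examples that use these helpers; the kernel body, sizes, and host arrays are
# hypothetical):
#
#   @NVDSL.mlir_func
#   def gemm(a, b, d):
#       @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=16384)
#       def kernel():
#           ...  # build the device code with the helpers above
#
#       kernel()
#
#   gemm(a_host, b_host, d_host)   # NumPy arrays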