diff --git a/llvm/test/Other/spirv-sim/branch.spv b/llvm/test/Other/spirv-sim/branch.spv new file mode 100644 index 000000000000..7ee0ebcad249 --- /dev/null +++ b/llvm/test/Other/spirv-sim/branch.spv @@ -0,0 +1,42 @@ +; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %} +; RUN: spirv-sim --function=simple --wave=3 --expects=5,6,6 -i %s + OpCapability Shader + OpCapability GroupNonUniform + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %WaveIndex + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %simple "simple" + OpName %main "main" + OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %bool = OpTypeBool + %int_2 = OpConstant %int 2 + %int_5 = OpConstant %int 5 + %int_6 = OpConstant %int 6 + %uint_0 = OpConstant %uint 0 + %void = OpTypeVoid + %main_type = OpTypeFunction %void +%simple_type = OpTypeFunction %int + %uint_iptr = OpTypePointer Input %uint + %WaveIndex = OpVariable %uint_iptr Input + %main = OpFunction %void None %main_type + %entry = OpLabel + OpReturn + OpFunctionEnd + %simple = OpFunction %int None %simple_type + %1 = OpLabel + %2 = OpLoad %uint %WaveIndex + %3 = OpIEqual %bool %uint_0 %2 + OpSelectionMerge %merge None + OpBranchConditional %3 %true %false + %true = OpLabel + OpBranch %merge + %false = OpLabel + OpBranch %merge + %merge = OpLabel + %4 = OpPhi %int %int_5 %true %int_6 %false + OpReturnValue %4 + OpFunctionEnd + diff --git a/llvm/test/Other/spirv-sim/call.spv b/llvm/test/Other/spirv-sim/call.spv new file mode 100644 index 000000000000..320b048f9529 --- /dev/null +++ b/llvm/test/Other/spirv-sim/call.spv @@ -0,0 +1,36 @@ +; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %} +; RUN: spirv-sim --function=simple --wave=1 --expects=2 -i %s + OpCapability Shader + OpCapability GroupNonUniform + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %WaveIndex + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %simple "simple" + OpName %main "main" + OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 + %void = OpTypeVoid + %main_type = OpTypeFunction %void +%simple_type = OpTypeFunction %int + %sub_type = OpTypeFunction %uint + %uint_iptr = OpTypePointer Input %uint + %WaveIndex = OpVariable %uint_iptr Input + %main = OpFunction %void None %main_type + %entry = OpLabel + OpReturn + OpFunctionEnd + %sub = OpFunction %uint None %sub_type + %a = OpLabel + OpReturnValue %uint_2 + OpFunctionEnd + %simple = OpFunction %int None %simple_type + %1 = OpLabel + %2 = OpFunctionCall %uint %sub + %3 = OpBitcast %int %2 + OpReturnValue %3 + OpFunctionEnd + + diff --git a/llvm/test/Other/spirv-sim/constant.spv b/llvm/test/Other/spirv-sim/constant.spv new file mode 100644 index 000000000000..1002427943a8 --- /dev/null +++ b/llvm/test/Other/spirv-sim/constant.spv @@ -0,0 +1,36 @@ +; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %} +; RUN: spirv-sim --function=a --wave=1 --expects=2 -i %s +; RUN: spirv-sim --function=b --wave=1 --expects=1 -i %s + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %a "a" + OpName %b "b" + OpName %main "main" + %int = OpTypeInt 32 1 + %s1 = OpTypeStruct %int %int %int + %s2 = OpTypeStruct %s1 + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 + %s1_1_2 = OpConstantComposite %s1 %int_1 %int_2 %int_1 + %s2_s1 = OpConstantComposite %s2 %s1_1_2 + %void = OpTypeVoid + %main_type = OpTypeFunction %void + %simple_type = OpTypeFunction %int + %main = OpFunction %void None %main_type + %entry = OpLabel + OpReturn + OpFunctionEnd + %a = OpFunction %int None %simple_type + %1 = OpLabel + %2 = OpCompositeExtract %int %s1_1_2 1 + OpReturnValue %2 + OpFunctionEnd + %b = OpFunction %int None %simple_type + %3 = OpLabel + %4 = OpCompositeExtract %int %s2_s1 0 2 + OpReturnValue %4 + OpFunctionEnd + diff --git a/llvm/test/Other/spirv-sim/lit.local.cfg b/llvm/test/Other/spirv-sim/lit.local.cfg new file mode 100644 index 000000000000..67a8d9196f58 --- /dev/null +++ b/llvm/test/Other/spirv-sim/lit.local.cfg @@ -0,0 +1,8 @@ +spirv_sim_root = os.path.join(config.llvm_src_root, "utils", "spirv-sim") +config.substitutions.append( + ( + "spirv-sim", + "'%s' %s" + % (config.python_executable, os.path.join(spirv_sim_root, "spirv-sim.py")), + ) +) diff --git a/llvm/test/Other/spirv-sim/loop.spv b/llvm/test/Other/spirv-sim/loop.spv new file mode 100644 index 000000000000..4fd0f1a7c96a --- /dev/null +++ b/llvm/test/Other/spirv-sim/loop.spv @@ -0,0 +1,58 @@ +; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %} +; RUN: spirv-sim --function=simple --wave=4 --expects=0,2,2,4 -i %s + OpCapability Shader + OpCapability GroupNonUniform + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %WaveIndex + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %simple "simple" + OpName %main "main" + OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %bool = OpTypeBool + %int_2 = OpConstant %int 2 + %int_5 = OpConstant %int 5 + %int_6 = OpConstant %int 6 + %uint_0 = OpConstant %uint 0 + %uint_2 = OpConstant %uint 2 + %void = OpTypeVoid + %main_type = OpTypeFunction %void +%simple_type = OpTypeFunction %int + %uint_iptr = OpTypePointer Input %uint + %uint_fptr = OpTypePointer Function %uint + %WaveIndex = OpVariable %uint_iptr Input + %main = OpFunction %void None %main_type + %unused = OpLabel + OpReturn + OpFunctionEnd + %simple = OpFunction %int None %simple_type + %entry = OpLabel +; uint i = 0; + %i = OpVariable %uint_fptr Function + %1 = OpLoad %uint %WaveIndex + OpStore %i %uint_0 + OpBranch %header + %header = OpLabel + %2 = OpLoad %uint %i + %3 = OpULessThan %bool %2 %1 + OpLoopMerge %merge %continue None + OpBranchConditional %3 %body %merge +; while (i < WaveGetLaneIndex()) { +; i += 2; +; } + %body = OpLabel + OpBranch %continue + %continue = OpLabel + %4 = OpIAdd %uint %2 %uint_2 + OpStore %i %4 + OpBranch %header + %merge = OpLabel +; return (int) i; + %5 = OpLoad %uint %i + %6 = OpBitcast %int %5 + OpReturnValue %6 + OpFunctionEnd + + diff --git a/llvm/test/Other/spirv-sim/simple-bad-result.spv b/llvm/test/Other/spirv-sim/simple-bad-result.spv new file mode 100644 index 000000000000..f4dd046cc078 --- /dev/null +++ b/llvm/test/Other/spirv-sim/simple-bad-result.spv @@ -0,0 +1,26 @@ +; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %} +; RUN: not spirv-sim --function=simple --wave=1 --expects=1 -i %s 2>&1 | FileCheck %s + +; CHECK: Expected != Observed +; CHECK: [1] != [2] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %simple "simple" + OpName %main "main" + %int = OpTypeInt 32 1 + %int_2 = OpConstant %int 2 + %void = OpTypeVoid + %main_type = OpTypeFunction %void + %simple_type = OpTypeFunction %int + %main = OpFunction %void None %main_type + %entry = OpLabel + OpReturn + OpFunctionEnd + %simple = OpFunction %int None %simple_type + %1 = OpLabel + OpReturnValue %int_2 + OpFunctionEnd + diff --git a/llvm/test/Other/spirv-sim/simple.spv b/llvm/test/Other/spirv-sim/simple.spv new file mode 100644 index 000000000000..8c06192ea6e3 --- /dev/null +++ b/llvm/test/Other/spirv-sim/simple.spv @@ -0,0 +1,22 @@ +; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %} +; RUN: spirv-sim --function=simple --wave=1 --expects=2 -i %s + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %simple "simple" + OpName %main "main" + %int = OpTypeInt 32 1 + %int_2 = OpConstant %int 2 + %void = OpTypeVoid + %main_type = OpTypeFunction %void + %simple_type = OpTypeFunction %int + %main = OpFunction %void None %main_type + %entry = OpLabel + OpReturn + OpFunctionEnd + %simple = OpFunction %int None %simple_type + %1 = OpLabel + OpReturnValue %int_2 + OpFunctionEnd diff --git a/llvm/test/Other/spirv-sim/simulator-args.spv b/llvm/test/Other/spirv-sim/simulator-args.spv new file mode 100644 index 000000000000..d8b101806415 --- /dev/null +++ b/llvm/test/Other/spirv-sim/simulator-args.spv @@ -0,0 +1,36 @@ +; RUN: not spirv-sim --function=simple --wave=a --expects=2 -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-WAVE +; RUN: not spirv-sim --function=simple --wave=1 --expects=a -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-EXPECT +; RUN: not spirv-sim --function=simple --wave=1 --expects=1, -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-EXPECT +; RUN: not spirv-sim --function=simple --wave=2 --expects=1 -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-SIZE +; RUN: not spirv-sim --function=foo --wave=1 --expects=1 -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-NAME + +; CHECK-WAVE: Invalid format for --wave/-w flag. + +; CHECK-EXPECT: Invalid format for --expects/-e flag. + +; CHECK-SIZE: Wave size != expected result array size + +; CHECK-NAME: 'foo' function not found. Known functions are: +; CHECK-NAME-NEXT: - main +; CHECK-NAME-NEXT: - simple +; CHECK-NANE-NOT-NEXT: - + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %simple "simple" + OpName %main "main" + %int = OpTypeInt 32 1 + %int_2 = OpConstant %int 2 + %void = OpTypeVoid + %main_type = OpTypeFunction %void + %simple_type = OpTypeFunction %int + %main = OpFunction %void None %main_type + %entry = OpLabel + OpReturn + OpFunctionEnd + %simple = OpFunction %int None %simple_type + %1 = OpLabel + OpReturnValue %int_2 + OpFunctionEnd diff --git a/llvm/test/Other/spirv-sim/switch.spv b/llvm/test/Other/spirv-sim/switch.spv new file mode 100644 index 000000000000..83dc56cecef2 --- /dev/null +++ b/llvm/test/Other/spirv-sim/switch.spv @@ -0,0 +1,42 @@ +; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %} +; RUN: spirv-sim --function=simple --wave=4 --expects=0,1,2,0 -i %s + OpCapability Shader + OpCapability GroupNonUniform + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %WaveIndex + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %simple "simple" + OpName %main "main" + OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %bool = OpTypeBool + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 + %uint_0 = OpConstant %uint 0 + %void = OpTypeVoid + %main_type = OpTypeFunction %void +%simple_type = OpTypeFunction %int + %uint_iptr = OpTypePointer Input %uint + %WaveIndex = OpVariable %uint_iptr Input + %main = OpFunction %void None %main_type + %entry = OpLabel + OpReturn + OpFunctionEnd + %simple = OpFunction %int None %simple_type + %1 = OpLabel + %2 = OpLoad %uint %WaveIndex + OpSelectionMerge %merge None + OpSwitch %2 %default 1 %case_1 2 %case_2 + %default = OpLabel + OpBranch %merge + %case_1 = OpLabel + OpBranch %merge + %case_2 = OpLabel + OpBranch %merge + %merge = OpLabel + %4 = OpPhi %int %int_0 %default %int_1 %case_1 %int_2 %case_2 + OpReturnValue %4 + OpFunctionEnd diff --git a/llvm/test/Other/spirv-sim/wave-get-lane-index.spv b/llvm/test/Other/spirv-sim/wave-get-lane-index.spv new file mode 100644 index 000000000000..1c1e5e8aefd4 --- /dev/null +++ b/llvm/test/Other/spirv-sim/wave-get-lane-index.spv @@ -0,0 +1,30 @@ +; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %} +; RUN: spirv-sim --function=simple --wave=4 --expects=0,1,2,3 -i %s + OpCapability Shader + OpCapability GroupNonUniform + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %WaveIndex + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %simple "simple" + OpName %main "main" + OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %int_2 = OpConstant %int 2 + %void = OpTypeVoid + %main_type = OpTypeFunction %void +%simple_type = OpTypeFunction %int + %uint_iptr = OpTypePointer Input %uint + %WaveIndex = OpVariable %uint_iptr Input + %main = OpFunction %void None %main_type + %entry = OpLabel + OpReturn + OpFunctionEnd + %simple = OpFunction %int None %simple_type + %1 = OpLabel + %2 = OpLoad %uint %WaveIndex + %3 = OpBitcast %int %2 + OpReturnValue %3 + OpFunctionEnd + diff --git a/llvm/test/Other/spirv-sim/wave-read-lane-first.spv b/llvm/test/Other/spirv-sim/wave-read-lane-first.spv new file mode 100644 index 000000000000..801fb55fbaa9 --- /dev/null +++ b/llvm/test/Other/spirv-sim/wave-read-lane-first.spv @@ -0,0 +1,83 @@ +; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %} +; RUN: spirv-sim --function=simple --wave=4 --expects=0,1,2,0 -i %s + +; int simple() { +; int m[4] = { 0, 1, 2, 0 }; +; int idx = WaveGetLaneIndex(); +; for (int i = 0; i < 4; i++) { +; if (i == m[idx]) { +; return WaveReadLaneFirst(idx); +; } +; } +; return 0; +; } + OpCapability Shader + OpCapability GroupNonUniform + OpCapability GroupNonUniformBallot + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %WaveIndex + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %simple "simple" + OpName %main "main" + OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %bool = OpTypeBool + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 + %int_4 = OpConstant %int 4 + %uint_3 = OpConstant %uint 3 + %uint_4 = OpConstant %uint 4 + %void = OpTypeVoid + %main_type = OpTypeFunction %void + %simple_type = OpTypeFunction %int + %uint_iptr = OpTypePointer Input %uint + %int_fptr = OpTypePointer Function %int + %arr_int_uint_4 = OpTypeArray %int %uint_4 +%arr_int_uint_4_fptr = OpTypePointer Function %arr_int_uint_4 + %WaveIndex = OpVariable %uint_iptr Input + %main = OpFunction %void None %main_type + %entry = OpLabel + OpReturn + OpFunctionEnd + %simple = OpFunction %int None %simple_type + %bb_entry_0 = OpLabel + %m = OpVariable %arr_int_uint_4_fptr Function + %idx = OpVariable %int_fptr Function + %i = OpVariable %int_fptr Function + %27 = OpCompositeConstruct %arr_int_uint_4 %int_0 %int_1 %int_2 %int_0 + OpStore %m %27 + %28 = OpLoad %uint %WaveIndex + %29 = OpBitcast %int %28 + OpStore %idx %29 + OpStore %i %int_0 + OpBranch %for_check + %for_check = OpLabel + %31 = OpLoad %int %i + %33 = OpSLessThan %bool %31 %int_4 + OpLoopMerge %for_merge %for_continue None + OpBranchConditional %33 %for_body %for_merge + %for_body = OpLabel + %37 = OpLoad %int %i + %38 = OpLoad %int %idx + %39 = OpAccessChain %int_fptr %m %38 + %40 = OpLoad %int %39 + %41 = OpIEqual %bool %37 %40 + OpSelectionMerge %if_merge None + OpBranchConditional %41 %if_true %if_merge + %if_true = OpLabel + %44 = OpLoad %int %idx + %45 = OpGroupNonUniformBroadcastFirst %int %uint_3 %44 + OpReturnValue %45 + %if_merge = OpLabel + OpBranch %for_continue + %for_continue = OpLabel + %47 = OpLoad %int %i + %48 = OpIAdd %int %47 %int_1 + OpStore %i %48 + OpBranch %for_check + %for_merge = OpLabel + OpReturnValue %int_0 + OpFunctionEnd diff --git a/llvm/test/lit.cfg.py b/llvm/test/lit.cfg.py index bee7aa3903a1..1e0dd0a7df34 100644 --- a/llvm/test/lit.cfg.py +++ b/llvm/test/lit.cfg.py @@ -22,7 +22,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. This is overriden # by individual lit.local.cfg files in the test subdirectories. -config.suffixes = [".ll", ".c", ".test", ".txt", ".s", ".mir", ".yaml"] +config.suffixes = [".ll", ".c", ".test", ".txt", ".s", ".mir", ".yaml", ".spv"] # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent diff --git a/llvm/utils/spirv-sim/instructions.py b/llvm/utils/spirv-sim/instructions.py new file mode 100644 index 000000000000..5e64a480a2be --- /dev/null +++ b/llvm/utils/spirv-sim/instructions.py @@ -0,0 +1,381 @@ +from typing import Optional, List + + +# Base class for an instruction. To implement a basic instruction that doesn't +# impact the control-flow, create a new class inheriting from this. +class Instruction: + # Contains the name of the output register, if any. + _result: Optional[str] + # Contains the instruction opcode. + _opcode: str + # Contains all the instruction operands, except result and opcode. + _operands: List[str] + + def __init__(self, line: str): + self.line = line + tokens = line.split() + if len(tokens) > 1 and tokens[1] == "=": + self._result = tokens[0] + self._opcode = tokens[2] + self._operands = tokens[3:] if len(tokens) > 2 else [] + else: + self._result = None + self._opcode = tokens[0] + self._operands = tokens[1:] if len(tokens) > 1 else [] + + def __str__(self): + if self._result is None: + return f" {self._opcode} {self._operands}" + return f"{self._result:3} = {self._opcode} {self._operands}" + + # Returns the instruction opcode. + def opcode(self) -> str: + return self._opcode + + # Returns the instruction operands. + def operands(self) -> List[str]: + return self._operands + + # Returns the instruction output register. Calling this function is + # only allowed if has_output_register() is true. + def output_register(self) -> str: + assert self._result is not None + return self._result + + # Returns true if this function has an output register. False otherwise. + def has_output_register(self) -> bool: + return self._result is not None + + # This function is used to initialize state related to this instruction + # before module execution begins. For example, global Input variables + # can use this to store the lane ID into the register. + def static_execution(self, lane): + pass + + # This function is called everytime this instruction is executed by a + # tangle. This function should not be directly overriden, instead see + # _impl and _advance_ip. + def runtime_execution(self, module, lane): + self._impl(module, lane) + self._advance_ip(module, lane) + + # This function needs to be overriden if your instruction can be executed. + # It implements the logic of the instruction. + # 'Static' instructions like OpConstant should not override this since + # they are not supposed to be executed at runtime. + def _impl(self, module, lane): + raise RuntimeError(f"Unimplemented instruction {self}") + + # By default, IP is incremented to point to the next instruction. + # If the instruction modifies IP (like OpBranch), this must be overridden. + def _advance_ip(self, module, lane): + lane.set_ip(lane.ip() + 1) + + +# Those are parsed, but never executed. +class OpEntryPoint(Instruction): + pass + + +class OpFunction(Instruction): + pass + + +class OpFunctionEnd(Instruction): + pass + + +class OpLabel(Instruction): + pass + + +class OpVariable(Instruction): + pass + + +class OpName(Instruction): + def name(self) -> str: + return self._operands[1][1:-1] + + def decoratedRegister(self) -> str: + return self._operands[0] + + +# The only decoration we use if the BuiltIn one to initialize the values. +class OpDecorate(Instruction): + def static_execution(self, lane): + if self._operands[1] == "LinkageAttributes": + return + + assert ( + self._operands[1] == "BuiltIn" + and self._operands[2] == "SubgroupLocalInvocationId" + ) + lane.set_register(self._operands[0], lane.tid()) + + +# Constants +class OpConstant(Instruction): + def static_execution(self, lane): + lane.set_register(self._result, int(self._operands[1])) + + +class OpConstantTrue(OpConstant): + def static_execution(self, lane): + lane.set_register(self._result, True) + + +class OpConstantFalse(OpConstant): + def static_execution(self, lane): + lane.set_register(self._result, False) + + +class OpConstantComposite(OpConstant): + def static_execution(self, lane): + result = [] + for op in self._operands[1:]: + result.append(lane.get_register(op)) + lane.set_register(self._result, result) + + +# Control flow instructions +class OpFunctionCall(Instruction): + def _impl(self, module, lane): + pass + + def _advance_ip(self, module, lane): + entry = module.get_function_entry(self._operands[1]) + lane.do_call(entry, self._result) + + +class OpReturn(Instruction): + def _impl(self, module, lane): + pass + + def _advance_ip(self, module, lane): + lane.do_return(None) + + +class OpReturnValue(Instruction): + def _impl(self, module, lane): + pass + + def _advance_ip(self, module, lane): + lane.do_return(lane.get_register(self._operands[0])) + + +class OpBranch(Instruction): + def _impl(self, module, lane): + pass + + def _advance_ip(self, module, lane): + lane.set_ip(module.get_bb_entry(self._operands[0])) + pass + + +class OpBranchConditional(Instruction): + def _impl(self, module, lane): + pass + + def _advance_ip(self, module, lane): + condition = lane.get_register(self._operands[0]) + if condition: + lane.set_ip(module.get_bb_entry(self._operands[1])) + else: + lane.set_ip(module.get_bb_entry(self._operands[2])) + + +class OpSwitch(Instruction): + def _impl(self, module, lane): + pass + + def _advance_ip(self, module, lane): + value = lane.get_register(self._operands[0]) + default_label = self._operands[1] + i = 2 + while i < len(self._operands): + imm = int(self._operands[i]) + label = self._operands[i + 1] + if value == imm: + lane.set_ip(module.get_bb_entry(label)) + return + i += 2 + lane.set_ip(module.get_bb_entry(default_label)) + + +class OpUnreachable(Instruction): + def _impl(self, module, lane): + raise RuntimeError("This instruction should never be executed.") + + +# Convergence instructions +class MergeInstruction(Instruction): + def merge_location(self): + return self._operands[0] + + def continue_location(self): + return None if len(self._operands) < 3 else self._operands[1] + + def _impl(self, module, lane): + lane.handle_convergence_header(self) + + +class OpLoopMerge(MergeInstruction): + pass + + +class OpSelectionMerge(MergeInstruction): + pass + + +# Other instructions +class OpBitcast(Instruction): + def _impl(self, module, lane): + # TODO: find out the type from the defining instruction. + # This can only work for DXC. + if self._operands[0] == "%int": + lane.set_register(self._result, int(lane.get_register(self._operands[1]))) + else: + raise RuntimeError("Unsupported OpBitcast operand") + + +class OpAccessChain(Instruction): + def _impl(self, module, lane): + # Python dynamic types allows me to simplify. As long as the SPIR-V + # is legal, this should be fine. + # Note: SPIR-V structs are stored as tuples + value = lane.get_register(self._operands[1]) + for operand in self._operands[2:]: + value = value[lane.get_register(operand)] + lane.set_register(self._result, value) + + +class OpCompositeConstruct(Instruction): + def _impl(self, module, lane): + output = [] + for op in self._operands[1:]: + output.append(lane.get_register(op)) + lane.set_register(self._result, output) + + +class OpCompositeExtract(Instruction): + def _impl(self, module, lane): + value = lane.get_register(self._operands[1]) + output = value + for op in self._operands[2:]: + output = output[int(op)] + lane.set_register(self._result, output) + + +class OpStore(Instruction): + def _impl(self, module, lane): + lane.set_register(self._operands[0], lane.get_register(self._operands[1])) + + +class OpLoad(Instruction): + def _impl(self, module, lane): + lane.set_register(self._result, lane.get_register(self._operands[1])) + + +class OpIAdd(Instruction): + def _impl(self, module, lane): + LHS = lane.get_register(self._operands[1]) + RHS = lane.get_register(self._operands[2]) + lane.set_register(self._result, LHS + RHS) + + +class OpISub(Instruction): + def _impl(self, module, lane): + LHS = lane.get_register(self._operands[1]) + RHS = lane.get_register(self._operands[2]) + lane.set_register(self._result, LHS - RHS) + + +class OpIMul(Instruction): + def _impl(self, module, lane): + LHS = lane.get_register(self._operands[1]) + RHS = lane.get_register(self._operands[2]) + lane.set_register(self._result, LHS * RHS) + + +class OpLogicalNot(Instruction): + def _impl(self, module, lane): + LHS = lane.get_register(self._operands[1]) + lane.set_register(self._result, not LHS) + + +class _LessThan(Instruction): + def _impl(self, module, lane): + LHS = lane.get_register(self._operands[1]) + RHS = lane.get_register(self._operands[2]) + lane.set_register(self._result, LHS < RHS) + + +class _GreaterThan(Instruction): + def _impl(self, module, lane): + LHS = lane.get_register(self._operands[1]) + RHS = lane.get_register(self._operands[2]) + lane.set_register(self._result, LHS > RHS) + + +class OpSLessThan(_LessThan): + pass + + +class OpULessThan(_LessThan): + pass + + +class OpSGreaterThan(_GreaterThan): + pass + + +class OpUGreaterThan(_GreaterThan): + pass + + +class OpIEqual(Instruction): + def _impl(self, module, lane): + LHS = lane.get_register(self._operands[1]) + RHS = lane.get_register(self._operands[2]) + lane.set_register(self._result, LHS == RHS) + + +class OpINotEqual(Instruction): + def _impl(self, module, lane): + LHS = lane.get_register(self._operands[1]) + RHS = lane.get_register(self._operands[2]) + lane.set_register(self._result, LHS != RHS) + + +class OpPhi(Instruction): + def _impl(self, module, lane): + previousBBName = lane.get_previous_bb_name() + i = 1 + while i < len(self._operands): + label = self._operands[i + 1] + if label == previousBBName: + lane.set_register(self._result, lane.get_register(self._operands[i])) + return + i += 2 + raise RuntimeError("previousBB not in the OpPhi _operands") + + +class OpSelect(Instruction): + def _impl(self, module, lane): + condition = lane.get_register(self._operands[1]) + value = lane.get_register(self._operands[2 if condition else 3]) + lane.set_register(self._result, value) + + +# Wave intrinsics +class OpGroupNonUniformBroadcastFirst(Instruction): + def _impl(self, module, lane): + assert lane.get_register(self._operands[1]) == 3 + if lane.is_first_active_lane(): + lane.broadcast_register(self._result, lane.get_register(self._operands[2])) + + +class OpGroupNonUniformElect(Instruction): + def _impl(self, module, lane): + lane.set_register(self._result, lane.is_first_active_lane()) diff --git a/llvm/utils/spirv-sim/spirv-sim.py b/llvm/utils/spirv-sim/spirv-sim.py new file mode 100755 index 000000000000..428b0ca4eb79 --- /dev/null +++ b/llvm/utils/spirv-sim/spirv-sim.py @@ -0,0 +1,658 @@ +#!/usr/bin/env python3 + +from __future__ import annotations +from dataclasses import dataclass +from instructions import * +from typing import Any, Iterable, Callable, Optional, Tuple, List, Dict +import argparse +import fileinput +import inspect +import re +import sys + +RE_EXPECTS = re.compile(r"^([0-9]+,)*[0-9]+$") + + +# Parse the SPIR-V instructions. Some instructions are ignored because +# not required to simulate this module. +# Instructions are to be implemented in instructions.py +def parseInstruction(i): + IGNORED = set( + [ + "OpCapability", + "OpMemoryModel", + "OpExecutionMode", + "OpExtension", + "OpSource", + "OpTypeInt", + "OpTypeStruct", + "OpTypeFloat", + "OpTypeBool", + "OpTypeVoid", + "OpTypeFunction", + "OpTypePointer", + "OpTypeArray", + ] + ) + if i.opcode() in IGNORED: + return None + + try: + Type = getattr(sys.modules["instructions"], i.opcode()) + except AttributeError: + raise RuntimeError(f"Unsupported instruction {i}") + if not inspect.isclass(Type): + raise RuntimeError( + f"{i} instruction definition is not a class. Did you used 'def' instead of 'class'?" + ) + return Type(i.line) + + +# Split a list of instructions into pieces. Pieces are delimited by instructions of the type splitType. +# The delimiter is the first instruction of the next piece. +# This function returns no empty pieces: +# - if 2 subsequent delimiters will mean 2 pieces. One with only the first delimiter, and the second +# with the delimiter and following instructions. +# - if the first instruction is a delimiter, the first piece will begin with this delimiter. +def splitInstructions( + splitType: type, instructions: Iterable[Instruction] +) -> List[List[Instruction]]: + blocks: List[List[Instruction]] = [[]] + for instruction in instructions: + if isinstance(instruction, splitType) and len(blocks[-1]) > 0: + blocks.append([]) + blocks[-1].append(instruction) + return blocks + + +# Defines a BasicBlock in the simulator. +# Begins at an OpLabel, and ends with a control-flow instruction. +class BasicBlock: + def __init__(self, instructions) -> None: + assert isinstance(instructions[0], OpLabel) + # The name of the basic block, which is the register of the leading + # OpLabel. + self._name = instructions[0].output_register() + # The list of instructions belonging to this block. + self._instructions = instructions[1:] + + # Returns the name of this basic block. + def name(self): + return self._name + + # Returns the instruction at index in this basic block. + def __getitem__(self, index: int) -> Instruction: + return self._instructions[index] + + # Returns the number of instructions in this basic block, excluding the + # leading OpLabel. + def __len__(self): + return len(self._instructions) + + def dump(self): + print(f" {self._name}:") + for instruction in self._instructions: + print(f" {instruction}") + + +# Defines a Function in the simulator. +class Function: + def __init__(self, instructions) -> None: + assert isinstance(instructions[0], OpFunction) + # The name of the function (name of the register returned by OpFunction). + self._name: str = instructions[0].output_register() + # The list of basic blocks that belongs to this function. + self._basic_blocks: List[BasicBlock] = [] + # The variables local to this function. + self._variables: List[OpVariable] = [ + x for x in instructions if isinstance(x, OpVariable) + ] + + assert isinstance(instructions[-1], OpFunctionEnd) + body = filter(lambda x: not isinstance(x, OpVariable), instructions[1:-1]) + for block in splitInstructions(OpLabel, body): + self._basic_blocks.append(BasicBlock(block)) + + # Returns the name of this function. + def name(self) -> str: + return self._name + + # Returns the basic block at index in this function. + def __getitem__(self, index: int) -> BasicBlock: + return self._basic_blocks[index] + + # Returns the index of the basic block with the given name if found, + # -1 otherwise. + def get_bb_index(self, name) -> int: + for i in range(len(self._basic_blocks)): + if self._basic_blocks[i].name() == name: + return i + return -1 + + def dump(self): + print(" Variables:") + for var in self._variables: + print(f" {var}") + print(" Blocks:") + for bb in self._basic_blocks: + bb.dump() + + +# Represents an instruction pointer in the simulator. +@dataclass +class InstructionPointer: + # The current function the IP points to. + function: Function + # The basic block index in function IP points to. + basic_block: int + # The instruction in basic_block IP points to. + instruction_index: int + + def __str__(self): + bb = self.function[self.basic_block] + i = bb[self.instruction_index] + return f"{bb.name()}:{self.instruction_index} in {self.function.name()} | {i}" + + def __hash__(self): + return hash((self.function.name(), self.basic_block, self.instruction_index)) + + # Returns the basic block IP points to. + def bb(self) -> BasicBlock: + return self.function[self.basic_block] + + # Returns the instruction IP points to. + def instruction(self): + return self.function[self.basic_block][self.instruction_index] + + # Increment IP by 1. This only works inside a basic-block boundary. + # Incrementing IP when at the boundary of a basic block will fail. + def __add__(self, value: int): + bb = self.function[self.basic_block] + assert len(bb) > self.instruction_index + value + return InstructionPointer( + self.function, self.basic_block, self.instruction_index + value + ) + + +# Defines a Lane in this simulator. +class Lane: + # The registers known by this lane. + _registers: Dict[str, Any] + # The current IP of this lane. + _ip: Optional[InstructionPointer] + # If this lane running. + _running: bool + # The wave this lane belongs to. + _wave: Wave + # The callstack of this lane. Each tuple represents 1 call. + # The first element is the IP the function will return to. + # The second element is the callback to call to store the return value + # into the correct register. + _callstack: List[Tuple[InstructionPointer, Callable[[Any], None]]] + + _previous_bb: Optional[BasicBlock] + _current_bb: Optional[BasicBlock] + + def __init__(self, wave: Wave, tid: int) -> None: + self._registers = dict() + self._ip = None + self._running = True + self._wave = wave + self._callstack = [] + + # The index of this lane in the wave. + self._tid = tid + # The last BB this lane was executing into. + self._previous_bb = None + # The current BB this lane is executing into. + self._current_bb = None + + # Returns the lane/thread ID of this lane in its wave. + def tid(self) -> int: + return self._tid + + # Returns true is this lane if the first by index in the current active tangle. + def is_first_active_lane(self) -> bool: + return self._tid == self._wave.get_first_active_lane_index() + + # Broadcast value into the registers of all active lanes. + def broadcast_register(self, register: str, value: Any) -> None: + self._wave.broadcast_register(register, value) + + # Returns the IP this lane is currently at. + def ip(self) -> InstructionPointer: + assert self._ip is not None + return self._ip + + # Returns true if this lane is running, false otherwise. + # Running means not dead. An inactive lane is running. + def running(self) -> bool: + return self._running + + # Set the register at "name" to "value" in this lane. + def set_register(self, name: str, value: Any) -> None: + self._registers[name] = value + + # Get the value in register "name" in this lane. + # If allow_undef is true, fetching an unknown register won't fail. + def get_register(self, name: str, allow_undef: bool = False) -> Optional[Any]: + if allow_undef and name not in self._registers: + return None + return self._registers[name] + + def set_ip(self, ip: InstructionPointer) -> None: + if ip.bb() != self._current_bb: + self._previous_bb = self._current_bb + self._current_bb = ip.bb() + self._ip = ip + + def get_previous_bb_name(self): + return self._previous_bb.name() + + def handle_convergence_header(self, instruction): + self._wave.handle_convergence_header(self, instruction) + + def do_call(self, ip, output_register): + return_ip = None if self._ip is None else self._ip + 1 + self._callstack.append( + (return_ip, lambda value: self.set_register(output_register, value)) + ) + self.set_ip(ip) + + def do_return(self, value): + ip, callback = self._callstack[-1] + self._callstack.pop() + + callback(value) + if len(self._callstack) == 0: + self._running = False + else: + self.set_ip(ip) + + +# Represents the SPIR-V module in the simulator. +class Module: + _functions: Dict[str, Function] + _prolog: List[Instruction] + _globals: List[Instruction] + _name2reg: Dict[str, str] + _reg2name: Dict[str, str] + + def __init__(self, instructions) -> None: + chunks = splitInstructions(OpFunction, instructions) + + # The instructions located outside of all functions. + self._prolog = chunks[0] + # The functions in this module. + self._functions = {} + # Global variables in this module. + self._globals = [ + x + for x in instructions + if isinstance(x, OpVariable) or issubclass(type(x), OpConstant) + ] + + # Helper dictionaries to get real names of registers, or registers by names. + self._name2reg = {} + self._reg2name = {} + for instruction in instructions: + if isinstance(instruction, OpName): + name = instruction.name() + reg = instruction.decoratedRegister() + self._name2reg[name] = reg + self._reg2name[reg] = name + + for chunk in chunks[1:]: + function = Function(chunk) + assert function.name() not in self._functions + self._functions[function.name()] = function + + # Returns the register matching "name" if any, None otherwise. + # This assumes names are unique. + def getRegisterFromName(self, name): + if name in self._name2reg: + return self._name2reg[name] + return None + + # Returns the name given to "register" if any, None otherwise. + def getNameFromRegister(self, register): + if register in self._reg2name: + return self._reg2name[register] + return None + + # Initialize the module before wave execution begins. + # See Instruction::static_execution for more details. + def initialize(self, lane): + for instruction in self._globals: + instruction.static_execution(lane) + + # Initialize builtins + for instruction in self._prolog: + if isinstance(instruction, OpDecorate): + instruction.static_execution(lane) + + def execute_one_instruction(self, lane: Lane, ip: InstructionPointer) -> None: + ip.instruction().runtime_execution(self, lane) + + # Returns the first valid IP for the function defined by the given register. + # Calling this with a register not returned by OpFunction is illegal. + def get_function_entry(self, register: str) -> InstructionPointer: + if register not in self._functions: + raise RuntimeError(f"Function defining {register} not found.") + return InstructionPointer(self._functions[register], 0, 0) + + # Returns the first valid IP for the basic block defined by register. + # Calling this with a register not returned by an OpLabel is illegal. + def get_bb_entry(self, register: str) -> InstructionPointer: + for name, function in self._functions.items(): + index = function.get_bb_index(register) + if index != -1: + return InstructionPointer(function, index, 0) + raise RuntimeError(f"Instruction defining {register} not found.") + + # Returns the list of function names in this module. + # If an OpName exists for this function, returns the pretty name, else + # returns the register name. + def get_function_names(self): + return [self.getNameFromRegister(reg) for reg, func in self._functions.items()] + + # Returns the global variables defined in this module. + def variables(self) -> Iterable: + return [x.output_register() for x in self._globals] + + def dump(self, function_name: Optional[str] = None): + print("Module:") + print(" globals:") + for instruction in self._globals: + print(f" {instruction}") + + if function_name is None: + print(" functions:") + for register, function in self._functions.items(): + name = self.getNameFromRegister(register) + print(f" Function {register} ({name})") + function.dump() + return + + register = self.getRegisterFromName(function_name) + print(f" function {register} ({function_name}):") + if register is not None: + self._functions[register].dump() + else: + print(f" error: cannot find function.") + + +# Defines a convergence requirement for the simulation: +# A list of lanes impacted by a merge and possibly the associated +# continue target. +@dataclass +class ConvergenceRequirement: + mergeTarget: InstructionPointer + continueTarget: Optional[InstructionPointer] + impactedLanes: set[int] + + +Task = Dict[InstructionPointer, List[Lane]] + + +# Defines a Lane group/Wave in the simulator. +class Wave: + # The module this wave will execute. + _module: Module + # The lanes this wave will be composed of. + _lanes: List[Lane] + # The instructions scheduled for execution. + _tasks: Task + # The actual requirements to comply with when executing instructions. + # E.g: the set of lanes required to merge before executing the merge block. + _convergence_requirements: List[ConvergenceRequirement] + # The indices of the active lanes for the current executing instruction. + _active_lane_indices: set[int] + + def __init__(self, module, wave_size: int) -> None: + assert wave_size > 0 + self._module = module + self._lanes = [] + + for i in range(wave_size): + self._lanes.append(Lane(self, i)) + + self._tasks = {} + self._convergence_requirements = [] + # The indices of the active lanes for the current executing instruction. + self._active_lane_indices = set() + + # Returns True if the given IP can be executed for the given list of lanes. + def _is_task_candidate(self, ip: InstructionPointer, lanes: List[Lane]): + merged_lanes: set[int] = set() + for lane in self._lanes: + if not lane.running(): + merged_lanes.add(lane.tid()) + + for requirement in self._convergence_requirements: + # This task is not executing a merge or continue target. + # Adding all lanes at those points into the ignore list. + if requirement.mergeTarget != ip and requirement.continueTarget != ip: + for tid in requirement.impactedLanes: + if self._lanes[tid].ip() == requirement.mergeTarget: + merged_lanes.add(tid) + if self._lanes[tid].ip() == requirement.continueTarget: + merged_lanes.add(tid) + continue + + # This task is executing the current requirement continue/merge + # target. + for tid in requirement.impactedLanes: + lane = self._lanes[tid] + if not lane.running(): + continue + + if lane.tid() in merged_lanes: + continue + + if ip == requirement.mergeTarget: + if lane.ip() != requirement.mergeTarget: + return False + else: + if ( + lane.ip() != requirement.mergeTarget + and lane.ip() != requirement.continueTarget + ): + return False + return True + + # Returns the next task we can schedule. This must always return a task. + # Calling this when all lanes are dead is invalid. + def _get_next_runnable_task(self) -> Tuple[InstructionPointer, List[Lane]]: + candidate = None + for ip, lanes in self._tasks.items(): + if len(lanes) == 0: + continue + if self._is_task_candidate(ip, lanes): + candidate = ip + break + + if candidate: + lanes = self._tasks[candidate] + del self._tasks[ip] + return (candidate, lanes) + raise RuntimeError("No task to execute. Deadlock?") + + # Handle an encountered merge instruction for the given lane. + def handle_convergence_header(self, lane: Lane, instruction: MergeInstruction): + mergeTarget = self._module.get_bb_entry(instruction.merge_location()) + for requirement in self._convergence_requirements: + if requirement.mergeTarget == mergeTarget: + requirement.impactedLanes.add(lane.tid()) + return + + continueTarget = None + if instruction.continue_location(): + continueTarget = self._module.get_bb_entry(instruction.continue_location()) + requirement = ConvergenceRequirement( + mergeTarget, continueTarget, set([lane.tid()]) + ) + self._convergence_requirements.append(requirement) + + # Returns true if some instructions are scheduled for execution. + def _has_tasks(self) -> bool: + return len(self._tasks) > 0 + + # Returns the index of the first active lane right now. + def get_first_active_lane_index(self) -> int: + return min(self._active_lane_indices) + + # Broadcast the given value to all active lane registers. + def broadcast_register(self, register: str, value: Any) -> None: + for tid in self._active_lane_indices: + self._lanes[tid].set_register(register, value) + + # Returns the entrypoint of the function associated with 'name'. + # Calling this function with an invalid name is illegal. + def _get_function_entry_from_name(self, name: str) -> InstructionPointer: + register = self._module.getRegisterFromName(name) + assert register is not None + return self._module.get_function_entry(register) + + # Run the wave on the function 'function_name' until all lanes are dead. + # If verbose is True, execution trace is printed. + # Returns the value returned by the function for each lane. + def run(self, function_name: str, verbose: bool = False) -> List[Any]: + for t in self._lanes: + self._module.initialize(t) + + entry_ip = self._get_function_entry_from_name(function_name) + assert entry_ip is not None + for t in self._lanes: + t.do_call(entry_ip, "__shader_output__") + + self._tasks[self._lanes[0].ip()] = self._lanes + while self._has_tasks(): + ip, lanes = self._get_next_runnable_task() + self._active_lane_indices = set([x.tid() for x in lanes]) + if verbose: + print( + f"Executing with lanes {self._active_lane_indices}: {ip.instruction()}" + ) + + for lane in lanes: + self._module.execute_one_instruction(lane, ip) + if not lane.running(): + continue + + if lane.ip() in self._tasks: + self._tasks[lane.ip()].append(lane) + else: + self._tasks[lane.ip()] = [lane] + + if verbose and ip.instruction().has_output_register(): + register = ip.instruction().output_register() + print( + f" {register:3} = {[ x.get_register(register, allow_undef=True) for x in lanes ]}" + ) + + output = [] + for lane in self._lanes: + output.append(lane.get_register("__shader_output__")) + return output + + def dump_register(self, register: str) -> None: + for lane in self._lanes: + print( + f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}" + ) + + +parser = argparse.ArgumentParser( + description="simulator", formatter_class=argparse.ArgumentDefaultsHelpFormatter +) +parser.add_argument( + "-i", "--input", help="Text SPIR-V to read from", required=False, default="-" +) +parser.add_argument("-f", "--function", help="Function to execute") +parser.add_argument("-w", "--wave", help="Wave size", default=32, required=False) +parser.add_argument( + "-e", + "--expects", + help="Expected results per lanes, expects a list of values. Ex: '1, 2, 3'.", +) +parser.add_argument("-v", "--verbose", help="verbose", action="store_true") +args = parser.parse_args() + + +def load_instructions(filename: str): + if filename is None: + return [] + + if filename.strip() != "-": + try: + with open(filename, "r") as f: + lines = f.read().split("\n") + except Exception: # (FileNotFoundError, PermissionError): + return [] + else: + lines = sys.stdin.readlines() + + # Remove leading/trailing whitespaces. + lines = [x.strip() for x in lines] + # Strip comments. + lines = [x for x in filter(lambda x: len(x) != 0 and x[0] != ";", lines)] + + instructions = [] + for i in [Instruction(x) for x in lines]: + out = parseInstruction(i) + if out != None: + instructions.append(out) + return instructions + + +def main(): + if args.expects is None or not RE_EXPECTS.match(args.expects): + print("Invalid format for --expects/-e flag.", file=sys.stderr) + sys.exit(1) + if args.function is None: + print("Invalid format for --function/-f flag.", file=sys.stderr) + sys.exit(1) + try: + int(args.wave) + except ValueError: + print("Invalid format for --wave/-w flag.", file=sys.stderr) + sys.exit(1) + + expected_results = [int(x.strip()) for x in args.expects.split(",")] + wave_size = int(args.wave) + if len(expected_results) != wave_size: + print("Wave size != expected result array size", file=sys.stderr) + sys.exit(1) + + instructions = load_instructions(args.input) + if len(instructions) == 0: + print("Invalid input. Expected a text SPIR-V module.") + sys.exit(1) + + module = Module(instructions) + if args.verbose: + module.dump() + module.dump(args.function) + + function_names = module.get_function_names() + if args.function not in function_names: + print( + f"'{args.function}' function not found. Known functions are:", + file=sys.stderr, + ) + for name in function_names: + print(f" - {name}", file=sys.stderr) + sys.exit(1) + + wave = Wave(module, wave_size) + results = wave.run(args.function, verbose=args.verbose) + + if expected_results != results: + print("Expected != Observed", file=sys.stderr) + print(f"{expected_results} != {results}", file=sys.stderr) + sys.exit(1) + sys.exit(0) + + +main()