diff --git a/docs/src/ttrt.md b/docs/src/ttrt.md index 3f13e0471..c9a0da352 100644 --- a/docs/src/ttrt.md +++ b/docs/src/ttrt.md @@ -46,6 +46,7 @@ ttrt run out.ttnn ttrt run out.ttnn --clean-artifacts ttrt run out.ttnn --save-artifacts ttrt run out.ttnn --loops 10 +ttrt run --program-index all out.ttnn ttrt run --program-index 0 out.ttnn ttrt run /dir/of/flatbuffers ttrt run /dir/of/flatbuffers --loops 10 @@ -74,6 +75,7 @@ ttrt perf out.ttnn ttrt perf out.ttnn --clean-artifacts ttrt perf out.ttnn --save-artifacts ttrt perf out.ttnn --loops 10 +ttrt perf --program-index all out.ttnn ttrt perf --program-index 0 out.ttnn ttrt perf --device out.ttnn ttrt perf --generate-params --perf-csv trace.csv diff --git a/runtime/tools/python/ttrt/__init__.py b/runtime/tools/python/ttrt/__init__.py index be197112c..f7883c980 100644 --- a/runtime/tools/python/ttrt/__init__.py +++ b/runtime/tools/python/ttrt/__init__.py @@ -62,7 +62,7 @@ def main(): run_parser = subparsers.add_parser("run", help="run a flatbuffer binary") run_parser.add_argument( "--program-index", - default=0, + default="all", help="the program inside the fbb to run", ) run_parser.add_argument( @@ -129,7 +129,7 @@ def main(): ) perf_parser.add_argument( "--program-index", - default=0, + default="all", help="the program inside the fbb to run", ) perf_parser.add_argument( diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index 50aded72c..d475edd30 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -94,7 +94,7 @@ def run(args): # acquire parameters arg_binary = args.binary - arg_program_index = int(args.program_index) + arg_program_index = args.program_index arg_clean_artifacts = args.clean_artifacts arg_loops = int(args.loops) arg_save_artifacts = args.save_artifacts @@ -117,19 +117,34 @@ def run(args): # constraint checking print("executing constraint for all provided flatbuffers") system_desc, device_ids = ttrt.runtime.get_current_system_desc() + program_indices = [] for binary in binaries: check_file_exists(binary) fbb = ttrt.binary.load_binary_from_path(binary) check_version(fbb.version) fbb_dict = ttrt.binary.as_dict(fbb) + assert ( fbb_dict["system_desc"] == system_desc_as_dict(system_desc)["system_desc"] ), f"system descriptor for binary and system mismatch!" - fbb_list.append((os.path.splitext(os.path.basename(binary))[0], fbb, fbb_dict)) - program_index = arg_program_index - assert program_index <= len( - fbb_dict["programs"] - ), "args.program_index out of range" + + if arg_program_index != "all": + program_index = int(arg_program_index) + assert program_index < len( + fbb_dict["programs"] + ), "args.program_index out of range" + program_indices.append(program_index) + else: + program_indices = [i for i in range(len(fbb_dict["programs"]))] + + fbb_list.append( + ( + os.path.splitext(os.path.basename(binary))[0], + fbb, + fbb_dict, + program_indices, + ) + ) # execution print("executing action for all provided flatbuffers") @@ -138,88 +153,104 @@ def run(args): torch.manual_seed(args.seed) - for (binary_name, fbb, fbb_dict) in fbb_list: - torch_inputs[binary_name] = [] - torch_outputs[binary_name] = [] - program = fbb_dict["programs"][program_index] - print( - f"running binary={binary_name} with program[{program_index}]:", - program["name"], - ) + for (binary_name, fbb, fbb_dict, program_indices) in fbb_list: + torch_inputs[binary_name] = {} + torch_outputs[binary_name] = {} - for i in program["inputs"]: - torch_tensor = torch.randn( - i["desc"]["shape"], - dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]), - ) - torch_inputs[binary_name].append(torch_tensor) - for i in program["outputs"]: - torch_tensor = torch.zeros( - i["desc"]["shape"], - dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]), + for program_index in program_indices: + torch_inputs[binary_name][program_index] = [] + torch_outputs[binary_name][program_index] = [] + + program = fbb_dict["programs"][program_index] + print( + f"running binary={binary_name} with program[{program_index}]:", + program["name"], ) - torch_outputs[binary_name].append(torch_tensor) - - print("inputs:\n", torch_inputs) - - total_inputs = [] - total_outputs = [] - for loop in range(arg_loops): - inputs = [] - outputs = [] - for i in torch_inputs[binary_name]: - inputs.append( - ttrt.runtime.create_tensor( - i.data_ptr(), - list(i.shape), - list(i.stride()), - i.element_size(), - toDataType(i.dtype), - ) + + for i in program["inputs"]: + torch_tensor = torch.randn( + i["desc"]["shape"], + dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]), + ) + torch_inputs[binary_name][program_index].append(torch_tensor) + for i in program["outputs"]: + torch_tensor = torch.zeros( + i["desc"]["shape"], + dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]), ) + torch_outputs[binary_name][program_index].append(torch_tensor) + + print("inputs:\n", torch_inputs) + + total_inputs = [] + total_outputs = [] + for loop in range(arg_loops): + inputs = [] + outputs = [] + for i in torch_inputs[binary_name][program_index]: + inputs.append( + ttrt.runtime.create_tensor( + i.data_ptr(), + list(i.shape), + list(i.stride()), + i.element_size(), + toDataType(i.dtype), + ) + ) - for i in torch_outputs[binary_name]: - outputs.append( - ttrt.runtime.create_tensor( - i.data_ptr(), - list(i.shape), - list(i.stride()), - i.element_size(), - toDataType(i.dtype), + for i in torch_outputs[binary_name][program_index]: + outputs.append( + ttrt.runtime.create_tensor( + i.data_ptr(), + list(i.shape), + list(i.stride()), + i.element_size(), + toDataType(i.dtype), + ) ) - ) - total_inputs.append(inputs) - total_outputs.append(outputs) + total_inputs.append(inputs) + total_outputs.append(outputs) - for loop in range(arg_loops): - ttrt.runtime.submit(device, fbb, 0, total_inputs[loop], total_outputs[loop]) - print(f"finished loop={loop}") - print("outputs:\n", torch_outputs) + for loop in range(arg_loops): + ttrt.runtime.submit( + device, fbb, program_index, total_inputs[loop], total_outputs[loop] + ) + print(f"finished loop={loop}") + print("outputs:\n", torch_outputs) # save artifacts if arg_save_artifacts: print("saving artifacts") for binary in binaries: - copy_ttnn_binary_into_artifact(binary) - binary_name = os.path.splitext(os.path.basename(binary))[0] - torch_input_tensors = torch_inputs[binary_name] - torch_output_tensors = torch_outputs[binary_name] + fbb_dict = ttrt.binary.as_dict(ttrt.binary.load_binary_from_path(binary)) + curr_program_indices = [] + + if arg_program_index != "all": + curr_program_indices.append(int(arg_program_index)) + else: + curr_program_indices = [i for i in range(len(fbb_dict["programs"]))] + + for program_index in curr_program_indices: + copy_ttnn_binary_into_artifact(binary) + binary_name = os.path.splitext(os.path.basename(binary))[0] + torch_input_tensors = torch_inputs[binary_name][program_index] + torch_output_tensors = torch_outputs[binary_name][program_index] + + for i, input in enumerate(torch_input_tensors): + save_torch_tensor_into_ttrt_artifacts( + input, f"{binary_name}/program_{program_index}_input_{i}.pt" + ) - for i, input in enumerate(torch_input_tensors): - save_torch_tensor_into_ttrt_artifacts( - input, f"{binary_name}/input_{i}.pt" - ) + for i, output in enumerate(torch_output_tensors): + save_torch_tensor_into_ttrt_artifacts( + output, f"{binary_name}/program_{program_index}_output_{i}.pt" + ) - for i, output in enumerate(torch_output_tensors): - save_torch_tensor_into_ttrt_artifacts( - output, f"{binary_name}/output_{i}.pt" + save_system_desc_into_ttrt_artifacts( + system_desc, f"{binary_name}/system_desc.ttsys" ) - save_system_desc_into_ttrt_artifacts( - system_desc, f"{binary_name}/system_desc.ttsys" - ) - """ API: query @@ -278,7 +309,7 @@ def perf(args): # acquire parameters arg_binary = args.binary - arg_program_index = int(args.program_index) + arg_program_index = args.program_index arg_clean_artifacts = args.clean_artifacts arg_perf_csv = args.perf_csv arg_loops = int(args.loops)