Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#297: Add support in ttrt for multiple programs within a single flatbuffer file #301

Merged
Merged 1 commit on Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/ttrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ ttrt run out.ttnn
ttrt run out.ttnn --clean-artifacts
ttrt run out.ttnn --save-artifacts
ttrt run out.ttnn --loops 10
ttrt run --program-index all out.ttnn
ttrt run --program-index 0 out.ttnn
ttrt run /dir/of/flatbuffers
ttrt run /dir/of/flatbuffers --loops 10
Expand Down Expand Up @@ -74,6 +75,7 @@ ttrt perf out.ttnn
ttrt perf out.ttnn --clean-artifacts
ttrt perf out.ttnn --save-artifacts
ttrt perf out.ttnn --loops 10
ttrt perf --program-index all out.ttnn
ttrt perf --program-index 0 out.ttnn
ttrt perf --device out.ttnn
ttrt perf --generate-params --perf-csv trace.csv
Expand Down
4 changes: 2 additions & 2 deletions runtime/tools/python/ttrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main():
run_parser = subparsers.add_parser("run", help="run a flatbuffer binary")
run_parser.add_argument(
"--program-index",
default=0,
default="all",
help="the program inside the fbb to run",
)
run_parser.add_argument(
Expand Down Expand Up @@ -129,7 +129,7 @@ def main():
)
perf_parser.add_argument(
"--program-index",
default=0,
default="all",
help="the program inside the fbb to run",
)
perf_parser.add_argument(
Expand Down
177 changes: 104 additions & 73 deletions runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def run(args):

# acquire parameters
arg_binary = args.binary
arg_program_index = int(args.program_index)
arg_program_index = args.program_index
arg_clean_artifacts = args.clean_artifacts
arg_loops = int(args.loops)
arg_save_artifacts = args.save_artifacts
Expand All @@ -117,19 +117,34 @@ def run(args):
# constraint checking
print("executing constraint for all provided flatbuffers")
system_desc, device_ids = ttrt.runtime.get_current_system_desc()
program_indices = []
for binary in binaries:
check_file_exists(binary)
fbb = ttrt.binary.load_binary_from_path(binary)
check_version(fbb.version)
fbb_dict = ttrt.binary.as_dict(fbb)

assert (
fbb_dict["system_desc"] == system_desc_as_dict(system_desc)["system_desc"]
), f"system descriptor for binary and system mismatch!"
fbb_list.append((os.path.splitext(os.path.basename(binary))[0], fbb, fbb_dict))
program_index = arg_program_index
assert program_index <= len(
fbb_dict["programs"]
), "args.program_index out of range"

if arg_program_index != "all":
program_index = int(arg_program_index)
assert program_index < len(
fbb_dict["programs"]
), "args.program_index out of range"
program_indices.append(program_index)
else:
program_indices = [i for i in range(len(fbb_dict["programs"]))]

fbb_list.append(
(
os.path.splitext(os.path.basename(binary))[0],
fbb,
fbb_dict,
program_indices,
)
)

# execution
print("executing action for all provided flatbuffers")
Expand All @@ -138,88 +153,104 @@ def run(args):

torch.manual_seed(args.seed)

for (binary_name, fbb, fbb_dict) in fbb_list:
torch_inputs[binary_name] = []
torch_outputs[binary_name] = []
program = fbb_dict["programs"][program_index]
print(
f"running binary={binary_name} with program[{program_index}]:",
program["name"],
)
for (binary_name, fbb, fbb_dict, program_indices) in fbb_list:
torch_inputs[binary_name] = {}
torch_outputs[binary_name] = {}

for i in program["inputs"]:
torch_tensor = torch.randn(
i["desc"]["shape"],
dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]),
)
torch_inputs[binary_name].append(torch_tensor)
for i in program["outputs"]:
torch_tensor = torch.zeros(
i["desc"]["shape"],
dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]),
for program_index in program_indices:
torch_inputs[binary_name][program_index] = []
torch_outputs[binary_name][program_index] = []

program = fbb_dict["programs"][program_index]
print(
f"running binary={binary_name} with program[{program_index}]:",
program["name"],
)
torch_outputs[binary_name].append(torch_tensor)

print("inputs:\n", torch_inputs)

total_inputs = []
total_outputs = []
for loop in range(arg_loops):
inputs = []
outputs = []
for i in torch_inputs[binary_name]:
inputs.append(
ttrt.runtime.create_tensor(
i.data_ptr(),
list(i.shape),
list(i.stride()),
i.element_size(),
toDataType(i.dtype),
)

for i in program["inputs"]:
torch_tensor = torch.randn(
i["desc"]["shape"],
dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]),
)
torch_inputs[binary_name][program_index].append(torch_tensor)
for i in program["outputs"]:
torch_tensor = torch.zeros(
i["desc"]["shape"],
dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]),
)
torch_outputs[binary_name][program_index].append(torch_tensor)

print("inputs:\n", torch_inputs)

total_inputs = []
total_outputs = []
for loop in range(arg_loops):
inputs = []
outputs = []
for i in torch_inputs[binary_name][program_index]:
inputs.append(
ttrt.runtime.create_tensor(
i.data_ptr(),
list(i.shape),
list(i.stride()),
i.element_size(),
toDataType(i.dtype),
)
)

for i in torch_outputs[binary_name]:
outputs.append(
ttrt.runtime.create_tensor(
i.data_ptr(),
list(i.shape),
list(i.stride()),
i.element_size(),
toDataType(i.dtype),
for i in torch_outputs[binary_name][program_index]:
outputs.append(
ttrt.runtime.create_tensor(
i.data_ptr(),
list(i.shape),
list(i.stride()),
i.element_size(),
toDataType(i.dtype),
)
)
)

total_inputs.append(inputs)
total_outputs.append(outputs)
total_inputs.append(inputs)
total_outputs.append(outputs)

for loop in range(arg_loops):
ttrt.runtime.submit(device, fbb, 0, total_inputs[loop], total_outputs[loop])
print(f"finished loop={loop}")
print("outputs:\n", torch_outputs)
for loop in range(arg_loops):
ttrt.runtime.submit(
device, fbb, program_index, total_inputs[loop], total_outputs[loop]
)
print(f"finished loop={loop}")
print("outputs:\n", torch_outputs)

# save artifacts
if arg_save_artifacts:
print("saving artifacts")
for binary in binaries:
copy_ttnn_binary_into_artifact(binary)
binary_name = os.path.splitext(os.path.basename(binary))[0]
torch_input_tensors = torch_inputs[binary_name]
torch_output_tensors = torch_outputs[binary_name]
fbb_dict = ttrt.binary.as_dict(ttrt.binary.load_binary_from_path(binary))
curr_program_indices = []

if arg_program_index != "all":
curr_program_indices.append(int(arg_program_index))
else:
curr_program_indices = [i for i in range(len(fbb_dict["programs"]))]

for program_index in curr_program_indices:
copy_ttnn_binary_into_artifact(binary)
binary_name = os.path.splitext(os.path.basename(binary))[0]
torch_input_tensors = torch_inputs[binary_name][program_index]
torch_output_tensors = torch_outputs[binary_name][program_index]

for i, input in enumerate(torch_input_tensors):
save_torch_tensor_into_ttrt_artifacts(
input, f"{binary_name}/program_{program_index}_input_{i}.pt"
)

for i, input in enumerate(torch_input_tensors):
save_torch_tensor_into_ttrt_artifacts(
input, f"{binary_name}/input_{i}.pt"
)
for i, output in enumerate(torch_output_tensors):
save_torch_tensor_into_ttrt_artifacts(
output, f"{binary_name}/program_{program_index}_output_{i}.pt"
)

for i, output in enumerate(torch_output_tensors):
save_torch_tensor_into_ttrt_artifacts(
output, f"{binary_name}/output_{i}.pt"
save_system_desc_into_ttrt_artifacts(
system_desc, f"{binary_name}/system_desc.ttsys"
)

save_system_desc_into_ttrt_artifacts(
system_desc, f"{binary_name}/system_desc.ttsys"
)


"""
API: query
Expand Down Expand Up @@ -278,7 +309,7 @@ def perf(args):

# acquire parameters
arg_binary = args.binary
arg_program_index = int(args.program_index)
arg_program_index = args.program_index
arg_clean_artifacts = args.clean_artifacts
arg_perf_csv = args.perf_csv
arg_loops = int(args.loops)
Expand Down
Loading