Skip to content

Commit

Permalink
[tools] supports mlir trunc by compare results automatically
Browse files Browse the repository at this point in the history
Firstly, use npz_tool.py to get compare.csv as compare results
npz_tool.py exp.npz got.npz -a -s compare.csv
Then use the following command to make a mlir file with a error graph
mlir_truncio.py -c compare.csv -i input.mlir -o output.mlir

Change-Id: Ieeeaa5dbdf812de05b506135f9c0051a1c643b17
  • Loading branch information
Watesoyan committed Nov 21, 2024
1 parent 82c9342 commit cd758f8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/numpy_helper/npz_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def parse_args(args_list):
help="Do statistics on int8 tensor for saturate ratio and low ratio")
parser.add_argument("--int8_tensor_close", type=int, default=1,
help="whether int8 tensor compare close")
parser.add_argument("--save", type=str, help="Save result as a csv file")
parser.add_argument("--save", '-s', type=str, help="Save result as a csv file")
parser.add_argument("--per_axis_compare", type=int, default=-1,
help="Compare along axis, usually along axis 1 as per-channel")
parser.add_argument("--fuzzy_match", action='store_true',
Expand Down
46 changes: 46 additions & 0 deletions python/tools/mlir_truncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================

def mlir_truncio(compare_result_file: str, input_mlir_file: str, output_mlir_file: str):
import pandas as pd
df = pd.read_csv(compare_result_file, sep=', ', header=0, engine='python')
names = df['name']
passes = df['pass']
inputs = list(names[passes == True])
outputs = list(names[passes == False])
if len(inputs) == len(passes):
print("[WARNING] no error")
return
cmd = f"tpuc-tool {input_mlir_file} --trunc-io='inputs="
for i in range(len(inputs)):
cmd += inputs[i]
if i != len(inputs) - 1:
cmd += ","
cmd += " outputs="
for i in range(len(outputs)):
cmd += outputs[i]
if i != len(outputs) - 1:
cmd += ","
cmd += f"' -o {output_mlir_file}"
import os
os.system(cmd)


if __name__ == '__main__':
# yapf: disable
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--comp_file", "-c", required=True, help="compare result file")
parser.add_argument("--input_mlir_file", "-i", type=str, required=True, help="input mlir file")
parser.add_argument("--output_mlir_file", "-o", type=str, required=True, help="output mlir file")
# yapf: enable
args = parser.parse_args()
mlir_truncio(args.comp_file, args.input_mlir_file, args.output_mlir_file)

0 comments on commit cd758f8

Please sign in to comment.