diff --git a/tuning/benchmark_dispatch.sh b/tuning/benchmark_dispatch.sh index dd8f29c..73a2fbd 100755 --- a/tuning/benchmark_dispatch.sh +++ b/tuning/benchmark_dispatch.sh @@ -12,9 +12,12 @@ readonly NAME="$(basename "$INPUT" .mlir)" # printf "Benchmarking $(basename ${INPUT}) on ${DEVICE}\n" +# Replace invalid characters in DEVICE variable +SANITIZED_DEVICE=$(echo "${DEVICE}" | sed 's/[^a-zA-Z0-9._-]/_/g') + timeout 16s ./tools/iree-benchmark-module --device="${DEVICE}" --module="${INPUT}" \ --hip_use_streams=true --hip_allow_inline_execution=true \ - --batch_size=1000 --benchmark_repetitions=3 > "${DIR}/benchmark_log_${DEVICE}.out" 2>&1 || (mv "$INPUT" "${DIR}/benchmark_failed" && exit 0) + --batch_size=1000 --benchmark_repetitions=3 > "${DIR}/benchmark_log_${SANITIZED_DEVICE}.out" 2>&1 || (mv "$INPUT" "${DIR}/benchmark_failed" && exit 0) -MEAN_TIME="$(grep --text real_time_mean "${DIR}/benchmark_log_${DEVICE}.out" | awk '{print $2}')" +MEAN_TIME="$(grep --text real_time_mean "${DIR}/benchmark_log_${SANITIZED_DEVICE}.out" | awk '{print $2}')" printf "%s\tMean Time: %.1f\n" "$(basename "$INPUT" .vmfb)" "$MEAN_TIME" diff --git a/tuning/tune.py b/tuning/candidate_gen.py similarity index 99% rename from tuning/tune.py rename to tuning/candidate_gen.py index 5ae981b..02f8eb9 100755 --- a/tuning/tune.py +++ b/tuning/candidate_gen.py @@ -24,7 +24,7 @@ from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen """ -Usage: ./tune.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk +Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk """ tune_logger = logging.getLogger("tune") diff --git a/tuning/compile_candidate.sh b/tuning/compile_candidate.sh index a50a1c6..9b9b74c 100755 --- a/tuning/compile_candidate.sh +++ b/tuning/compile_candidate.sh @@ -8,11 +8,11 @@ readonly DIR="$(dirname "$INPUT")" readonly BASENAME="$(basename "$INPUT" .mlir)" readonly OUT="${DIR}/compiled/${BASENAME}.vmfb" -mkdir -p "${DIR}/compiled" "${DIR}/failed" "${DIR}/configs" +mkdir -p "${DIR}/compiled" "${DIR}/failed" "${DIR}/specs" timeout 4s ./punet.sh "$INPUT" -o "$OUT" --compile-from=executable-sources 2>/dev/null || (mv "$INPUT" "$DIR/failed" && exit 1) tools/iree-dump-module "$OUT" | grep -q 'rocm-hsaco-fb' || (mv "$INPUT" "$DIR/failed" && rm -f "$OUT" && exit 1) if [ -f "${DIR}/${BASENAME}_config.mlir" ]; then - cat "${DIR}/../config_prolog.mlir" "${DIR}/${BASENAME}_config.mlir" "${DIR}/../config_epilog.mlir" > "${DIR}/configs/${BASENAME}_spec.mlir" + cat "${DIR}/../config_prolog.mlir" "${DIR}/${BASENAME}_config.mlir" "${DIR}/../config_epilog.mlir" > "${DIR}/specs/${BASENAME}_spec.mlir" fi echo "Compiling ${INPUT}: success" diff --git a/tuning/autotune.py b/tuning/libtuner.py similarity index 78% rename from tuning/autotune.py rename to tuning/libtuner.py index 4b3f7a6..8661354 100755 --- a/tuning/autotune.py +++ b/tuning/libtuner.py @@ -15,7 +15,7 @@ import time import multiprocessing import queue -import tune +import candidate_gen from tqdm import tqdm import re import hashlib @@ -24,23 +24,7 @@ import pickle import iree.runtime as ireert import random - -""" -Sample Usage: - -python autotune.py winograd 1286.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=1,3,5 --num-candidates=64 - - -Recommended Trial Run: - -python autotune.py winograd 1286.mlir --num-candidates=1 - - -Dry Run Test (no gpu requried): - -python autotune.py winograd 1286.mlir --num-candidates=64 --num-unet-candidates=10 --dry-run - 
-""" +from abc import ABC, abstractmethod # Default values for num_candidates and devices, change it as needed @@ -62,19 +46,19 @@ @dataclass class CandidateTracker: candidate_id: int - mlir_path: Optional[Path] = None - mlir_config_path: Optional[Path] = None - configuration: Optional[tune.Configuration] = None + dispatch_mlir_path: Optional[Path] = None + dispatch_config_path: Optional[Path] = None + configuration: Optional[candidate_gen.Configuration] = None compilation_successful: Optional[bool] = None - compiled_vmfb_path: Optional[Path] = None - compiled_vmfb_hash: Optional[str] = None + compiled_dispatch_path: Optional[Path] = None + compiled_dispatch_hash: Optional[str] = None first_benchmark_time: Optional[float] = None first_benchmark_device_id: Optional[int] = None - mlir_spec_path: Optional[Path] = None - unet_candidate_path: Optional[Path] = None - unet_vmfb_hash: Optional[str] = None - unet_benchmark_time: Optional[float] = None - unet_benchmark_device_id: Optional[int] = None + spec_path: Optional[Path] = None + model_path: Optional[Path] = None + compiled_model_hash: Optional[str] = None + model_benchmark_time: Optional[float] = None + model_benchmark_device_id: Optional[int] = None baseline_benchmark_time: Optional[float] = None calibrated_benchmark_diff: Optional[float] = None @@ -84,11 +68,7 @@ class PathConfig: # Preset constants global_config_prolog_mlir: Path = Path("./config_prolog.mlir") global_config_epilog_mlir: Path = Path("./config_epilog.mlir") - compile_candidate_sh: Path = Path("./compile_candidate.sh") - benchmark_dispatch_sh: Path = Path("./benchmark_dispatch.sh") - compile_unet_candidate_sh: Path = Path("./compile_unet_candidate.sh") - benchmark_unet_candidate_sh: Path = Path("./benchmark_unet_candidate.sh") - unet_baseline_vmfb: Path = Path("./unet_baseline.vmfb") + model_baseline_vmfb: Path = Path("./baseline.vmfb") # Dynamic paths base_dir: Path = field(init=False) @@ -98,7 +78,8 @@ class PathConfig: candidates_dir: Path = field(init=False) candidate_configs_pkl: Path = field(init=False) compiled_dir: Path = field(init=False) - compilefailed_dir: Path = field(init=False) + compile_failed_dir: Path = field(init=False) + spec_dir: Path = field(init=False) output_unilog: Path = field(init=False) result_summary_log: Path = field(init=False) @@ -121,7 +102,8 @@ def __post_init__(self): self, "candidate_configs_pkl", self.candidates_dir / "configs.pkl" ) object.__setattr__(self, "compiled_dir", self.candidates_dir / "compiled") - object.__setattr__(self, "compilefailed_dir", self.candidates_dir / "failed") + object.__setattr__(self, "compile_failed_dir", self.candidates_dir / "failed") + object.__setattr__(self, "spec_dir", self.candidates_dir / "specs") object.__setattr__(self, "output_unilog", self.base_dir / "output.log") object.__setattr__( self, "result_summary_log", self.base_dir / "result_summary.log" @@ -142,11 +124,40 @@ def get_candidate_mlir_path(self, candidate_id: int) -> Path: return self.candidates_dir / f"{candidate_id}.mlir" def get_candidate_spec_mlir_path(self, candidate_id: int) -> Path: - return self.candidates_dir / "configs" / f"{candidate_id}_spec.mlir" + return self.candidates_dir / "specs" / f"{candidate_id}_spec.mlir" def get_exe_format(self, path: Path) -> str: return f"./{path.as_posix()}" + def get_compiled_dispatch_index(self, file_path: Path) -> int: + return int(file_path.stem) + + def get_candidate_spec_filename(self, candidate_id: int) -> str: + return f"{candidate_id}_spec.mlir" + + def get_compiled_model_index(self, 
file_path: Path) -> int: + return int(file_path.stem.split("_")[-1]) + + +class TuningClient(ABC): + @abstractmethod + def get_dispatch_compile_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: + pass + + @abstractmethod + def get_dispatch_benchmark_command(self, candidate_tracker) -> list[str]: + pass + + @abstractmethod + def get_model_compile_command(self, candidate_tracker) -> list[str]: + pass + + @abstractmethod + def get_model_benchmark_command(self, candidate_tracker) -> list[str]: + pass + @dataclass class TaskTuple: @@ -209,7 +220,7 @@ def generate_sample_result( @dataclass -class UnetBenchmarkResult: +class ModelBenchmarkResult: result_str: Optional[str] = None def get_tokens(self) -> list[str]: @@ -221,15 +232,15 @@ def get_tokens(self) -> list[str]: except: return [] - def get_unet_candidate_path(self) -> Optional[str]: + def get_model_candidate_path(self) -> Optional[str]: if len(self.get_tokens()) < 2: return None return self.get_tokens()[1] def get_candidate_id(self) -> Optional[int]: - if self.get_unet_candidate_path(): + if self.get_model_candidate_path(): try: - path_str = self.get_unet_candidate_path() + path_str = self.get_model_candidate_path() return int(path_str.split("_")[-1].split(".")[0]) if path_str else None except ValueError: return None @@ -346,10 +357,10 @@ def validate_devices(user_devices: list[str]) -> None: class ExecutionPhases(str, Enum): dont_stop = "" generate_candidates = "generate-candidates" - compile_candidates = "compile-candidates" - benchmark_candidates = "benchmark-candidates" - compile_unet_candidates = "compile-unet-candidates" - benchmark_unet_candidates = "benchmark-unet-candidates" + compile_dispatches = "compile-dispatches" + benchmark_dispatches = "benchmark-dispatches" + compile_models = "compile-models" + benchmark_models = "benchmark-models" def parse_arguments() -> argparse.Namespace: @@ -386,7 +397,7 @@ def parse_arguments() -> argparse.Namespace: help="Stop execution after specified phase", ) parser.add_argument( - "--num-unet-candidates", + "--num-model-candidates", help="Maximum number of stage 2 candidates", type=int, default=50, @@ -397,12 +408,12 @@ def parse_arguments() -> argparse.Namespace: help="Do not attempt to run any modules or initialize the IREE runtime", ) - # tune.tune() options + # candidate_gen.tune() options parser.add_argument( "--num-candidates", type=int, default=DEFAULT_NUM_CANDIDATES, - help=f"Number of candidates to be generated by tune.py (default: {DEFAULT_NUM_CANDIDATES})", + help=f"Number of candidates to be generated by candidate_gen.py (default: {DEFAULT_NUM_CANDIDATES})", ) parser.add_argument( "--num-subgroups", @@ -466,7 +477,7 @@ def format(self, record): verbose_console_handler.setFormatter(file_formatter) logging.getLogger().addHandler(verbose_console_handler) - # config logger in tune.py + # config logger in candidate_gen.py tune_logger = logging.getLogger("tune") tune_logger.setLevel(logging.DEBUG) @@ -673,6 +684,11 @@ def load_pickle(file_path: Path) -> list[Any]: return loaded_array +def save_pickle(file_path: Path, input_list: list[Any]) -> None: + with open(file_path, "wb") as file: + pickle.dump(input_list, file) + + def append_to_file(lines: list[str], filepath: Path, title: str = "") -> None: """Appends new content to the end of the output.log.""" title_str = "=" * 5 + f" {title} " + "=" * 5 + "\n" if title != "" else "" @@ -686,6 +702,7 @@ def generate_candidates( args: argparse.Namespace, path_config: PathConfig, candidate_trackers: list[CandidateTracker], + 
tuning_client: TuningClient, ) -> list[int]: """Generate candidate files for tuning. Returns the list of candidate indexes""" logging.info("generate_candidates()") @@ -708,8 +725,8 @@ def generate_candidates( mlirs = [] try: - logging.debug("Captured messages from tune.py:") - tune.tune( + logging.debug("Captured messages from candidate_gen.py:") + candidate_gen.tune( input=str(path_config.template_mlir), output=str(path_config.candidates_dir), limit=args.num_candidates, @@ -723,14 +740,14 @@ def generate_candidates( ) except Exception as e: logging.error("An error occurred during candidates generation: %s", str(e)) - # Capture and log debug messages from tune.py + # Capture and log debug messages from candidate_gen.py tune_logger = logging.getLogger("tune") for handler in logging.getLogger().handlers: if isinstance(handler, logging.FileHandler): tune_logger.handlers.append(handler) - tune_logger.exception("Error in tune.py:") + tune_logger.exception("Error in candidate_gen.py:") raise - logging.debug("tune.py ends") + logging.debug("candidate_gen.py ends") candidate_configs = load_pickle(path_config.candidate_configs_pkl) candidate_configs.insert(0, None) # No Configuration class for 0.mlir @@ -743,19 +760,21 @@ def generate_candidates( candidates.append(int(mlir.stem)) new_candidate = CandidateTracker( candidate_id=int(mlir.stem), - mlir_path=mlir, + dispatch_mlir_path=mlir, configuration=candidate_configs[int(mlir.stem)], ) candidate_trackers.append(new_candidate) else: - candidate_trackers[int(mlir.stem.split("_config")[0])].mlir_config_path = ( - mlir - ) + candidate_trackers[ + int(mlir.stem.split("_config")[0]) + ].dispatch_config_path = mlir handle_error( condition=(len(candidates) == 0), msg="Failed to generate any candidates" ) + logging.critical(f"Generated [{len(candidates)}] candidates") + return candidates @@ -777,27 +796,29 @@ def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, lis return collision_detected, unique_indexes -def compile_candidates( +def compile_dispatches( args: argparse.Namespace, path_config: PathConfig, candidates: list[int], candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, ) -> list[int]: """Compile candidate files for tuning and record in candidate_vmfbs.txt. 
     Returns the list of compiled candidate indexes."""
     logging.info("compile_candidates()")
-    task_list = []
-    for candidate_index in candidates:
-        mlir_path = candidate_trackers[candidate_index].mlir_path
-        assert mlir_path is not None
-        command = [
-            path_config.get_exe_format(path_config.compile_candidate_sh),
-            args.mode,
-            mlir_path.as_posix(),
-        ]
-        task_list.append(TaskTuple(args, command, check=False))
+    if not candidates:
+        logging.info("No candidates to compile.")
+        return []
 
-    num_worker = max(min(args.max_cpu_workers, len(task_list)), 1)  # at least 1 worker
+    task_list = [
+        TaskTuple(
+            args,
+            tuning_client.get_dispatch_compile_command(candidate_trackers[i]),
+            check=False,
+        )
+        for i in candidates
+    ]
+    num_worker = min(args.max_cpu_workers, len(task_list))
     multiprocess_progress_wrapper(
         num_worker=num_worker, task_list=task_list, function=run_command_wrapper
     )
@@ -806,7 +827,7 @@
         path_config.compiled_dir.glob("*.vmfb"), key=numerical_sort_key
     )
     failed_files = sorted(
-        path_config.compilefailed_dir.glob("*.mlir"), key=numerical_sort_key
+        path_config.compile_failed_dir.glob("*.mlir"), key=numerical_sort_key
     )
 
     total, good, bad = len(task_list), len(compiled_files), len(failed_files)
@@ -817,19 +838,19 @@
     # Update candidate tracker
     for failed_file in failed_files:
-        index = int(failed_file.stem)
+        index = path_config.get_compiled_dispatch_index(failed_file)
         candidate_trackers[index].compilation_successful = False
 
     compiled_candidates = []
     compiled_candidates_hash_list = []
     for compiled_file in compiled_files:
-        index = int(compiled_file.stem)
+        index = path_config.get_compiled_dispatch_index(compiled_file)
         compiled_candidates.append(index)
         candidate_trackers[index].compilation_successful = True
-        candidate_trackers[index].compiled_vmfb_path = compiled_file
-        compiled_vmfb_path = candidate_trackers[index].compiled_vmfb_path
+        candidate_trackers[index].compiled_dispatch_path = compiled_file
+        compiled_vmfb_path = candidate_trackers[index].compiled_dispatch_path
         assert compiled_vmfb_path is not None
         hash_val = calculate_md5(compiled_vmfb_path)
-        candidate_trackers[index].compiled_vmfb_hash = hash_val
+        candidate_trackers[index].compiled_dispatch_hash = hash_val
         compiled_candidates_hash_list.append((index, hash_val))
 
     handle_error(
@@ -854,6 +875,7 @@ def parse_dispatch_benchmark_results(
     path_config: PathConfig,
     benchmark_results: list[TaskResult],
     candidate_trackers: list[CandidateTracker],
+    tuning_client: TuningClient,
 ) -> tuple[list[ParsedDisptachBenchmarkResult], list[str]]:
     benchmark_result_configs = []
     dump_list = []
@@ -867,12 +889,12 @@
         benchmark_time = res.get_benchmark_time()
         assert candidate_id is not None and benchmark_time is not None
         candidate_trackers[candidate_id].first_benchmark_time = benchmark_time
-        candidate_trackers[candidate_id].mlir_spec_path = (
-            path_config.get_candidate_spec_mlir_path(candidate_id)
+        candidate_trackers[candidate_id].spec_path = (
+            path_config.spec_dir / path_config.get_candidate_spec_filename(candidate_id)
         )
-        mlir_path = candidate_trackers[candidate_id].mlir_path
-        mlir_spec_path = candidate_trackers[candidate_id].mlir_spec_path
-        assert mlir_path is not None and mlir_spec_path is not None
+        mlir_path = candidate_trackers[candidate_id].dispatch_mlir_path
+        spec_path = candidate_trackers[candidate_id].spec_path
+        assert mlir_path is not None and spec_path is not None
         dump_list.append(res_str)
 
         benchmark_result_configs.append(
@@ -881,7 +903,7 @@ def
parse_dispatch_benchmark_results( candidate_id, benchmark_time, mlir_path, - mlir_spec_path, + spec_path, ) ) ) @@ -905,11 +927,12 @@ def generate_dryrun_dispatch_benchmark_results( return task_results -def benchmark_compiled_candidates( +def benchmark_dispatches( args: argparse.Namespace, path_config: PathConfig, compiled_candidates: list[int], candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, ): """Benchmark the candidate files and store the topN results in file (best.log).""" logging.info("benchmark_top_candidates()") @@ -921,18 +944,15 @@ def benchmark_compiled_candidates( ) else: # Benchmarking dispatch candidates - task_list = [] - for index in compiled_candidates: - compiled_vmfb_path = candidate_trackers[index].compiled_vmfb_path - assert compiled_vmfb_path is not None - command = [ - path_config.get_exe_format(path_config.benchmark_dispatch_sh), - compiled_vmfb_path.as_posix(), - ] - task_list.append( - TaskTuple(args, command, check=False, command_need_device_id=True) + task_list = [ + TaskTuple( + args, + tuning_client.get_dispatch_benchmark_command(candidate_trackers[i]), + check=False, + command_need_device_id=True, ) - + for i in compiled_candidates + ] worker_context_queue = create_worker_context_queue(args.devices) benchmark_results = multiprocess_progress_wrapper( num_worker=len(args.devices), @@ -946,7 +966,7 @@ def benchmark_compiled_candidates( parsed_benchmark_results, dispatch_benchmark_dump_list, ) = parse_dispatch_benchmark_results( - path_config, benchmark_results, candidate_trackers + path_config, benchmark_results, candidate_trackers, tuning_client ) append_to_file( dispatch_benchmark_dump_list, @@ -966,7 +986,7 @@ def benchmark_compiled_candidates( # Select top candidates best_results = sorted( parsed_benchmark_results, key=lambda x: float(x.benchmark_time_in_seconds) - )[: args.num_unet_candidates] + )[: args.num_model_candidates] logging.critical(f"Selected top[{len(best_results)}]") dump_list = [ @@ -981,65 +1001,62 @@ def benchmark_compiled_candidates( return top_candidates -def compile_unet_candidates( +def compile_models( args: argparse.Namespace, path_config: PathConfig, candidates: list[int], candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, ) -> list[int]: """Compile U-Net candidates stored in best.log. 
Return the list of U-Net candidate files.""" - logging.info("compile_unet_candidates()") + logging.info("compile_models()") if args.dry_run: return candidates - task_list = [] - for index in candidates: - if index == 0: - continue - mlir_spec_path = candidate_trackers[index].mlir_spec_path - assert mlir_spec_path is not None - command = [ - path_config.get_exe_format(path_config.compile_unet_candidate_sh), - args.mode, - mlir_spec_path.as_posix(), - ] - task_list.append(TaskTuple(args, command)) + if not candidates: + logging.info("No model candidates to compile.") + return [] - num_worker = max(min(args.max_cpu_workers, len(task_list)), 1) # at least 1 worker + task_list = [ + TaskTuple(args, tuning_client.get_model_compile_command(candidate_trackers[i])) + for i in candidates + if i != 0 + ] + num_worker = min(args.max_cpu_workers, len(task_list)) multiprocess_progress_wrapper( num_worker=num_worker, task_list=task_list, function=run_command_wrapper ) - unet_candidates_files = list(path_config.base_dir.glob("*.vmfb")) + model_candidates_files = list(path_config.base_dir.glob("*.vmfb")) - unet_candidates_indexes = [] - unet_candidates_hash_list = [] + model_candidates_indexes = [] + model_candidates_hash_list = [] # Update candidate tracker - for unet_candidate in unet_candidates_files: - assert unet_candidate is not None - index = int(unet_candidate.stem.split("_")[-1]) - candidate_trackers[index].unet_candidate_path = unet_candidate - hash_val = calculate_md5(unet_candidate) - candidate_trackers[index].unet_vmfb_hash = hash_val - unet_candidates_hash_list.append((index, hash_val)) - unet_candidates_indexes.append(index) - - # Check if unet candidate produces tbe same .vmfb - collision_detected, unique_unet_candidates_indexes = collision_handler( - unet_candidates_hash_list + for model_candidate in model_candidates_files: + assert model_candidate is not None + index = path_config.get_compiled_model_index(model_candidate) + candidate_trackers[index].model_path = model_candidate + hash_val = calculate_md5(model_candidate) + candidate_trackers[index].compiled_model_hash = hash_val + model_candidates_hash_list.append((index, hash_val)) + model_candidates_indexes.append(index) + + # Check if model candidate produces tbe same .vmfb + collision_detected, unique_model_candidates_indexes = collision_handler( + model_candidates_hash_list ) if collision_detected: logging.critical( - f"Remains [{len(unique_unet_candidates_indexes)}] unique candidate indexes" + f"Remains [{len(unique_model_candidates_indexes)}] unique candidate indexes" ) return ( - unique_unet_candidates_indexes + unique_model_candidates_indexes if collision_detected - else unet_candidates_indexes + else model_candidates_indexes ) @@ -1097,20 +1114,20 @@ def parse_grouped_benchmark_results( for same_device_results in grouped_benchmark_results: dump_unsort_list: list[tuple[float, str]] = [] - for unet_candidate_result in same_device_results: + for model_candidate_result in same_device_results: # Skip if benchmark failed. - result_str = unet_candidate_result.result.stdout + result_str = model_candidate_result.result.stdout if result_str is None: continue - res = UnetBenchmarkResult(result_str) + res = ModelBenchmarkResult(result_str) device_id = res.get_device_id() # Record baseline benchmarking result. 
- unet_candidate_path = res.get_unet_candidate_path() + model_candidate_path = res.get_model_candidate_path() if ( - unet_candidate_path is not None - and str(path_config.unet_baseline_vmfb) in unet_candidate_path + model_candidate_path is not None + and str(path_config.model_baseline_vmfb) in model_candidate_path ): baseline_time = res.get_benchmark_time() if baseline_time is None: @@ -1126,8 +1143,8 @@ def parse_grouped_benchmark_results( if candidate_time is None: incomplete_list.append((c_id, device_id)) continue - candidate_trackers[c_id].unet_benchmark_time = candidate_time - candidate_trackers[c_id].unet_benchmark_device_id = device_id + candidate_trackers[c_id].model_benchmark_time = candidate_time + candidate_trackers[c_id].model_benchmark_device_id = device_id # Skip improvement calculation if no baseline data. if baseline_time is None: dump_unsort_list.append((candidate_time, result_str)) @@ -1142,7 +1159,7 @@ def parse_grouped_benchmark_results( assert dump_str is not None dump_unsort_list.append((candidate_time, dump_str)) - # Sort unet candidate benchmarking result str in ascending time order. + # Sort model candidate benchmarking result str in ascending time order. dump_list = dump_list + [ dump_str for _, dump_str in sorted(dump_unsort_list, key=lambda x: x[0]) ] @@ -1150,9 +1167,9 @@ def parse_grouped_benchmark_results( # Store incomplete .vmfb file at the end of dump_list. for index, device_id in incomplete_list: index_to_path = lambda index: ( - f"{path_config.unet_baseline_vmfb.as_posix()}" + f"{path_config.model_baseline_vmfb.as_posix()}" if index == 0 - else f"{candidate_trackers[index].unet_candidate_path}" + else f"{candidate_trackers[index].model_path}" ) error_msg = f"Benchmarking result of {index_to_path(index)} on deivce {device_id} is incomplete" handle_error(condition=True, msg=error_msg, level=logging.WARNING) @@ -1172,7 +1189,7 @@ def generate_dryrun_unet_benchmark_results( task_result = subprocess.CompletedProcess( args=[""], returncode=0, - stdout=UnetBenchmarkResult().generate_sample_result( + stdout=ModelBenchmarkResult().generate_sample_result( candidate_vmfb_path_str=candidate_vmfb_path.as_posix(), device_id=device_id, t1=start, @@ -1190,7 +1207,7 @@ def dryrun_benchmark_unet( candidate_trackers: list[CandidateTracker], ): - unet_vmfb_paths = [path_config.unet_baseline_vmfb] + [ + unet_vmfb_paths = [path_config.model_baseline_vmfb] + [ Path(f"unet_candidate_{index}.vmfb") for index in unet_candidates ] benchmark_results = generate_dryrun_unet_benchmark_results(unet_vmfb_paths) @@ -1205,39 +1222,33 @@ def dryrun_benchmark_unet( ) -def benchmark_unet( +def benchmark_models( args: argparse.Namespace, path_config: PathConfig, - unet_candidates: list[int], + model_candidates: list[int], candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, ): """Benchmark U-Net candidate files and log the results.""" - logging.info("benchmark_unet()") + logging.info("benchmark_models()") if args.dry_run: - dryrun_benchmark_unet(path_config, unet_candidates, candidate_trackers) + dryrun_benchmark_unet(path_config, model_candidates, candidate_trackers) return - # Benchmarking unet candidates + # Benchmarking model candidates worker_context_queue = create_worker_context_queue(args.devices) - benchmark_task_list = [] - for index in unet_candidates: - unet_candidate_path = candidate_trackers[index].unet_candidate_path - assert unet_candidate_path is not None - command = [ - path_config.get_exe_format(path_config.benchmark_unet_candidate_sh), - 
unet_candidate_path.as_posix(), - ] - benchmark_task_list.append( - TaskTuple( - args, - command, - check=False, - command_need_device_id=True, - cooling_time=10, - result_need_device_id=True, - ) + benchmark_task_list = [ + TaskTuple( + args, + tuning_client.get_model_benchmark_command(candidate_trackers[i]), + check=False, + command_need_device_id=True, + cooling_time=10, + result_need_device_id=True, ) + for i in model_candidates + ] benchmark_results = multiprocess_progress_wrapper( num_worker=len(args.devices), task_list=benchmark_task_list, @@ -1249,14 +1260,12 @@ def benchmark_unet( grouped_benchmark_results = group_benchmark_results_by_device_id(benchmark_results) # Benchmarking baselines on each involved device + candidate_trackers[0].model_path = path_config.model_baseline_vmfb worker_context_queue = create_worker_context_queue(args.devices) baseline_task_list = [ TaskTuple( args, - command=[ - path_config.get_exe_format(path_config.benchmark_unet_candidate_sh), - path_config.unet_baseline_vmfb.as_posix(), - ], + tuning_client.get_model_benchmark_command(candidate_trackers[0]), check=False, command_need_device_id=True, result_need_device_id=True, @@ -1276,13 +1285,13 @@ def benchmark_unet( [x] + y for x, y in zip(baseline_results, grouped_benchmark_results) ] - # Update candidate_tracker and extract strings which will be stored in unet_result_log + # Update candidate_tracker and extract strings which will be stored later dump_list = parse_grouped_benchmark_results( path_config, grouped_benchmark_results, candidate_trackers ) append_to_file( - dump_list, filepath=path_config.output_unilog, title="Unet Benchmark Results" + dump_list, filepath=path_config.output_unilog, title="Model Benchmark Results" ) @@ -1292,10 +1301,10 @@ def summerize_top_candidates( dump_list = [] top_candidates = [] for candidate in candidate_trackers: - if candidate.candidate_id == 0 or candidate.unet_benchmark_time is None: + if candidate.candidate_id == 0 or candidate.model_benchmark_time is None: continue top_candidates.append( - (candidate.candidate_id, candidate.unet_benchmark_time) + (candidate.candidate_id, candidate.model_benchmark_time) ) # collect (id, time) top_candidates = sorted( @@ -1305,88 +1314,11 @@ def summerize_top_candidates( for candidate_id in top_candidate_ids: candidate = candidate_trackers[candidate_id] - assert candidate.mlir_config_path is not None - with open(candidate.mlir_config_path, "r") as file: + assert candidate.dispatch_config_path is not None + with open(candidate.dispatch_config_path, "r") as file: config_file_contents = file.read() - final_str = f"Candidate {candidate.candidate_id}:\nUnet benchmark time: {candidate.unet_benchmark_time} on device {candidate.unet_benchmark_device_id}\nDispatch benchmark time: {candidate.first_benchmark_time} on device {candidate.unet_benchmark_device_id}\nSpec file path: {candidate.mlir_spec_path}\nSpec contents:{config_file_contents}\n\n" + final_str = f"Candidate {candidate.candidate_id}:\nModel benchmark time: {candidate.model_benchmark_time} on device {candidate.model_benchmark_device_id}\nDispatch benchmark time: {candidate.first_benchmark_time} on device {candidate.model_benchmark_device_id}\nSpec file path: {candidate.spec_path}\nSpec contents:{config_file_contents}\n\n" dump_list.append(final_str) with open(path_config.result_summary_log, "w") as file: file.writelines(dump_list) - - -def autotune(args: argparse.Namespace) -> None: - path_config = PathConfig() - path_config.base_dir.mkdir(parents=True, exist_ok=True) - 
path_config.output_unilog.touch() - - candidate_trackers: list[CandidateTracker] = [] - stop_after_phase: str = args.stop_after - - print("Setup logging") - setup_logging(args, path_config) - print(path_config.run_log, end="\n\n") - - print("Validating devices") - validate_devices(args.devices) - print("Validation successful!\n") - - print("Generating candidates...") - candidates = generate_candidates(args, path_config, candidate_trackers) - print(f"Generated [{len(candidates)}] candidates in {path_config.candidates_dir}\n") - if stop_after_phase == ExecutionPhases.generate_candidates: - return - - print("Compiling candidates...") - compiled_candidates = compile_candidates( - args, path_config, candidates, candidate_trackers - ) - print(f"Compiled files are stored in {path_config.compiled_dir}\n") - if stop_after_phase == ExecutionPhases.compile_candidates: - return - - print("Benchmarking compiled candidates...") - top_candidates = benchmark_compiled_candidates( - args, path_config, compiled_candidates, candidate_trackers - ) - print(f"Stored results in {path_config.output_unilog}\n") - - if stop_after_phase == ExecutionPhases.benchmark_candidates: - return - - print(f"Compiling top unet candidates...") - unet_candidates = compile_unet_candidates( - args, path_config, top_candidates, candidate_trackers - ) - print(f"Unet candidates compiled in {path_config.base_dir}\n") - if stop_after_phase == ExecutionPhases.compile_unet_candidates: - return - - print("Benchmarking unet candidates...") - benchmark_unet(args, path_config, unet_candidates, candidate_trackers) - print(f"Stored results in {path_config.output_unilog}") - if stop_after_phase == ExecutionPhases.benchmark_unet_candidates: - return - - summerize_top_candidates(path_config, candidate_trackers) - print(f"Stored top candidates info in {path_config.result_summary_log}\n") - - with open(path_config.candidate_trackers_pkl, "wb") as file: - pickle.dump(candidate_trackers, file) - print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n") - - print("Check the detailed execution logs in:") - print(path_config.run_log) - - for candidate in candidate_trackers: - logging.debug(candidate) - if args.verbose: - print(candidate) - - -def main(): - autotune(parse_arguments()) - - -if __name__ == "__main__": - main() diff --git a/tuning/punet_autotune.py b/tuning/punet_autotune.py new file mode 100644 index 0000000..014c76d --- /dev/null +++ b/tuning/punet_autotune.py @@ -0,0 +1,151 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import libtuner +from pathlib import Path + + +""" +Sample Usage: + +python punet_autotune.py winograd 1286.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=hip://0,hip://1 --num-candidates=64 + + +Recommended Trial Run: + +python punet_autotune.py winograd 1286.mlir --num-candidates=1 + + +Dry Run Test (no gpu requried): + +python punet_autotune.py winograd 1286.mlir --num-candidates=64 --num-model-candidates=10 --dry-run + +""" + + +class PunetClient(libtuner.TuningClient): + + def get_dispatch_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + mlir_path = candidate_tracker.dispatch_mlir_path + assert mlir_path is not None + command = [ + "./compile_candidate.sh", + "winograd", + mlir_path.as_posix(), + ] + return command + + def get_dispatch_benchmark_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + compiled_vmfb_path = candidate_tracker.compiled_dispatch_path + assert compiled_vmfb_path is not None + command = [ + "./benchmark_dispatch.sh", + compiled_vmfb_path.as_posix(), + ] + return command + + def get_model_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + mlir_spec_path = candidate_tracker.spec_path + assert mlir_spec_path is not None + command = [ + "./compile_unet_candidate.sh", + "winograd", + mlir_spec_path.as_posix(), + ] + return command + + def get_model_benchmark_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + unet_candidate_path = candidate_tracker.model_path + assert unet_candidate_path is not None + command = [ + "./benchmark_unet_candidate.sh", + unet_candidate_path.as_posix(), + ] + return command + + +def main(): + args = libtuner.parse_arguments() + path_config = libtuner.PathConfig() + path_config.base_dir.mkdir(parents=True, exist_ok=True) + path_config.output_unilog.touch() + candidate_trackers: list[libtuner.CandidateTracker] = [] + punet_client = PunetClient() + stop_after_phase: str = args.stop_after + + print("Setup logging") + libtuner.setup_logging(args, path_config) + print(path_config.run_log, end="\n\n") + + print("Validating devices") + libtuner.validate_devices(args.devices) + print("Validation successful!\n") + + print("Generating candidates...") + candidates = libtuner.generate_candidates( + args, path_config, candidate_trackers, punet_client + ) + print(f"Generated [{len(candidates)}] candidates in {path_config.candidates_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.generate_candidates: + return + + print("Compiling candidates...") + compiled_candidates = libtuner.compile_dispatches( + args, path_config, candidates, candidate_trackers, punet_client + ) + print(f"Compiled files are stored in {path_config.compiled_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches: + return + + print("Benchmarking compiled candidates...") + top_candidates = libtuner.benchmark_dispatches( + args, path_config, compiled_candidates, candidate_trackers, punet_client + ) + print(f"Stored results in {path_config.output_unilog}\n") + if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches: + return + + print(f"Compiling top model candidates...") + punet_candidates = libtuner.compile_models( + args, path_config, top_candidates, candidate_trackers, punet_client + ) + print(f"Model candidates compiled in {path_config.base_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.compile_models: + return + + 
print("Benchmarking model candidates...") + libtuner.benchmark_models( + args, path_config, punet_candidates, candidate_trackers, punet_client + ) + print(f"Stored results in {path_config.output_unilog}") + if stop_after_phase == libtuner.ExecutionPhases.benchmark_models: + return + + libtuner.summerize_top_candidates(path_config, candidate_trackers) + print(f"Stored top candidates info in {path_config.result_summary_log}\n") + + libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers) + print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n") + + print("Check the detailed execution logs in:") + print(path_config.run_log) + + for candidate in candidate_trackers: + libtuner.logging.debug(candidate) + if args.verbose: + print(candidate) + + +if __name__ == "__main__": + main() diff --git a/tuning/test_tune.py b/tuning/test_candidate_gen.py similarity index 53% rename from tuning/test_tune.py rename to tuning/test_candidate_gen.py index 3e52863..ad9b97e 100644 --- a/tuning/test_tune.py +++ b/tuning/test_candidate_gen.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import pytest -import tune +import candidate_gen """ Usage: python -m pytest test_tune.py @@ -13,30 +13,54 @@ def test_get_shaped_type_element_bitwidth(): - assert tune.ShapedType([1024, 2048], tune.ElementType.i8).bitwidth == 8 - assert tune.ShapedType([2048], tune.ElementType.i32).bitwidth == 32 - assert tune.ShapedType([2048, 512, 384], tune.ElementType.f8).bitwidth == 8 - assert tune.ShapedType([1, 1], tune.ElementType.f16).bitwidth == 16 + assert ( + candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth + == 8 + ) + assert ( + candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 + ) + assert ( + candidate_gen.ShapedType( + [2048, 512, 384], candidate_gen.ElementType.f8 + ).bitwidth + == 8 + ) + assert ( + candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 + ) def test_get_shaped_type_to_str(): - assert str(tune.ShapedType([1024, 2048], tune.ElementType.i8)) == "1024x2048xi8" - assert str(tune.ShapedType([1024], tune.ElementType.f32)) == "1024xf32" - assert str(tune.ShapedType([1, 2, 3], tune.ElementType.f16)) == "1x2x3xf16" - assert str(tune.ShapedType([-1, 2, 3], tune.ElementType.f16)) == "?x2x3xf16" + assert ( + str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) + == "1024x2048xi8" + ) + assert ( + str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) + == "1024xf32" + ) + assert ( + str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) + == "1x2x3xf16" + ) + assert ( + str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) + == "?x2x3xf16" + ) def test_parse_tensor_type(): - assert tune.parse_tensor_type("tensor<1x2x3xf32>") == tune.ShapedType( - [1, 2, 3], tune.ElementType.f32 - ) - assert tune.parse_tensor_type("tensor<123xi8>") == tune.ShapedType( - [123], tune.ElementType.i8 - ) + assert candidate_gen.parse_tensor_type( + "tensor<1x2x3xf32>" + ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) + assert candidate_gen.parse_tensor_type( + "tensor<123xi8>" + ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) def test_get_mmt_tile_sizes(): - config = tune.Configuration( + config = candidate_gen.Configuration( subgroup_size=0, workgroup_size=[], intrinsic="", @@ -45,11 +69,11 @@ def test_get_mmt_tile_sizes(): subgroup_n_count=0, waves_per_eu=0, ) - assert 
tune.get_mmt_tile_sizes(config) == [128, 320, 32] + assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] def test_get_conv_tile_sizes(): - config = tune.Configuration( + config = candidate_gen.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], intrinsic="#iree_gpu.mma_layout", @@ -58,11 +82,11 @@ def test_get_conv_tile_sizes(): subgroup_n_count=4, waves_per_eu=1, ) - assert tune.get_conv_tile_sizes(config) == [1, 1, 464, 320, 1, 1, 16] + assert candidate_gen.get_conv_tile_sizes(config) == [1, 1, 464, 320, 1, 1, 16] def test_get_contract_tile_sizes(): - config = tune.Configuration( + config = candidate_gen.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], intrinsic="", @@ -71,14 +95,18 @@ def test_get_contract_tile_sizes(): subgroup_n_count=1, waves_per_eu=2, ) - assert tune.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] - assert tune.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] - assert tune.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] - assert tune.get_contract_tile_sizes(config, ["k", "k", "k"]) == [16, 16, 16] + assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] + assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] + assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] + assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ + 16, + 16, + 16, + ] def test_get_pipeline_config(): - config1 = tune.Configuration( + config1 = candidate_gen.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], intrinsic="", @@ -87,7 +115,7 @@ def test_get_pipeline_config(): subgroup_n_count=1, waves_per_eu=2, ) - config2 = tune.Configuration( + config2 = candidate_gen.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], intrinsic="", @@ -96,9 +124,9 @@ def test_get_pipeline_config(): subgroup_n_count=1, waves_per_eu=4, ) - assert tune.get_pipeline_config(config1) == ", prefetch_shared_memory" + assert candidate_gen.get_pipeline_config(config1) == ", prefetch_shared_memory" assert ( - tune.get_pipeline_config(config2) + candidate_gen.get_pipeline_config(config2) == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' ) @@ -110,12 +138,12 @@ def test_get_shapes_mmt(): r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', r"^bb0(%in: f16, %in_0: f16, %out: f32):", ] - assert tune.get_shapes_mmt(template) == tune.ProblemSize( - tune.MatmulSize(2048, 1280, 1280), - tune.ShapedType([2048, 1280], tune.ElementType.f16), - tune.ShapedType([1280, 1280], tune.ElementType.f16), - tune.ShapedType([2048, 1280], tune.ElementType.f32), - tune.DispatchKind.mmt, + assert candidate_gen.get_shapes_mmt(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, ) @@ -125,12 +153,12 @@ def test_get_shapes_conv(): r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = 
#iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", ] - assert tune.get_shapes_conv(template) == tune.ProblemSize( - tune.MatmulSize(32, 256, 11520), - tune.ShapedType([1, 3, 34, 1280], tune.ElementType.f16), - tune.ShapedType([3, 3, 1280, 256], tune.ElementType.f16), - tune.ShapedType([1, 1, 32, 256], tune.ElementType.f32), - tune.DispatchKind.conv, + assert candidate_gen.get_shapes_conv(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(32, 256, 11520), + candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.conv, ) @@ -141,12 +169,14 @@ def test_get_shapes_contract(): r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', r"^bb0(%in: f16, %in_0: f16, %out: f32):", ] - assert tune.get_shapes_contract(template, "mk", "nk") == tune.ProblemSize( - tune.MatmulSize(2048, 1280, 1280), - tune.ShapedType([2048, 1280], tune.ElementType.f16), - tune.ShapedType([1280, 1280], tune.ElementType.f16), - tune.ShapedType([2048, 1280], tune.ElementType.f32), - tune.DispatchKind.contraction, + assert candidate_gen.get_shapes_contract( + template, "mk", "nk" + ) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.contraction, ) @@ -156,12 +186,14 @@ def test_get_shapes_batch_matmul(): "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", ] - assert tune.get_shapes_batch_matmul(template, "bmk", "bkn") == tune.ProblemSize( - tune.MatmulSize(32, 32, 1024, 1), - tune.ShapedType([1, 32, 1024], tune.ElementType.f32), - tune.ShapedType([1, 1024, 32], tune.ElementType.f32), - tune.ShapedType([1, 32, 32], tune.ElementType.f32), - tune.DispatchKind.batch_matmul, + assert candidate_gen.get_shapes_batch_matmul( + template, "bmk", "bkn" + ) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(32, 32, 1024, 1), + candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, ) @@ -171,122 +203,138 @@ def test_get_shapes_batch_mmt(): r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, 
d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", ] - assert tune.get_shapes_batch_mmt(template) == tune.ProblemSize( - tune.MatmulSize(4096, 640, 640, 2), - tune.ShapedType([2, 4096, 640], tune.ElementType.i8), - tune.ShapedType([2, 640, 640], tune.ElementType.i8), - tune.ShapedType([2, 4096, 640], tune.ElementType.i32), - tune.DispatchKind.batch_mmt, + assert candidate_gen.get_shapes_batch_mmt(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.batch_mmt, ) def test_mfma_intrinsic_to_str(): - assert str(tune.MfmaIntrinsic.mfma_f16_16x16x16_f32()) == "MFMA_F16_16x16x16_F32" - assert str(tune.MfmaIntrinsic.mfma_i8_32x32x16_i32()) == "MFMA_I8_32x32x16_I32" + assert ( + str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) + == "MFMA_F16_16x16x16_F32" + ) + assert ( + str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) + == "MFMA_I8_32x32x16_I32" + ) def test_get_compatible_mfma_intrinsics(): - assert tune.get_compatible_mfma_intrinsics( - tune.ProblemSize( - tune.MatmulSize(2048, 1280, 1280), - tune.ShapedType([2048, 1280], tune.ElementType.f16), - tune.ShapedType([1280, 1280], tune.ElementType.f16), - tune.ShapedType([2048, 1280], tune.ElementType.f32), - tune.DispatchKind.mmt, + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, ) ) == [ - tune.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - tune.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), ] - assert tune.get_compatible_mfma_intrinsics( - tune.ProblemSize( - tune.MatmulSize(2048, 1280, 1280), - tune.ShapedType([2048, 1280], tune.ElementType.i8), - tune.ShapedType([1280, 1280], tune.ElementType.i8), - tune.ShapedType([2048, 1280], tune.ElementType.i32), - tune.DispatchKind.mmt, + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.mmt, ) ) == [ - tune.MfmaIntrinsic.mfma_i8_16x16x32_i32(), - tune.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + candidate_gen.MfmaIntrinsic.mfma_i8_16x16x32_i32(), + candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), ] - assert tune.get_compatible_mfma_intrinsics( - tune.ProblemSize( - tune.MatmulSize(968, 320, 640, 64), - tune.ShapedType([64, 968, 640], tune.ElementType.f32), - tune.ShapedType([64, 640, 320], 
tune.ElementType.f32), - tune.ShapedType([64, 968, 320], tune.ElementType.f32), - tune.DispatchKind.batch_matmul, + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(968, 320, 640, 64), + candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, ) ) == [ - tune.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - tune.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), ] def test_generate_solutions(): - matmul_size = tune.MatmulSize(2048, 3840, 1280) - lhs_type = tune.ShapedType([2048, 1280], tune.ElementType.f16) - rhs_type = tune.ShapedType([3840, 1280], tune.ElementType.f16) - res_type = tune.ShapedType([2048, 3840], tune.ElementType.f32) - problem_size = tune.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, tune.DispatchKind.mmt - ) - configs = tune.generate_solutions(problem_size, 4) + matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) + lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + configs = candidate_gen.generate_solutions(problem_size, 4) assert configs is not None def test_calculate_shared_memory_usage_in_bytes(): - matmul_size = tune.MatmulSize(1024, 1024, 1024) - lhs_type = tune.ShapedType([1024, 1024], tune.ElementType.f16) - rhs_type = tune.ShapedType([1024, 1024], tune.ElementType.f16) - res_type = tune.ShapedType([1024, 1024], tune.ElementType.f32) - problem_size = tune.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, tune.DispatchKind.mmt + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt ) assert ( - tune.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 147456 ) - lhs_type = tune.ShapedType([1024, 1024], tune.ElementType.i8) - problem_size = tune.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, tune.DispatchKind.mmt + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt ) assert ( - tune.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 81920 + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) + == 81920 ) - rhs_type = tune.ShapedType([1024, 1024], tune.ElementType.i32) - problem_size = tune.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, tune.DispatchKind.mmt + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) + problem_size = candidate_gen.ProblemSize( + matmul_size, 
lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt ) assert ( - tune.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) == 12288 + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) + == 12288 ) def test_generate_constraints_valid_input(): - matmul_size = tune.MatmulSize(1024, 1024, 1024) - lhs_type = tune.ShapedType([1024, 1024], tune.ElementType.f16) - rhs_type = tune.ShapedType([1024, 1024], tune.ElementType.f16) - res_type = tune.ShapedType([1024, 1024], tune.ElementType.f32) - problem_size = tune.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, tune.DispatchKind.mmt + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt ) # Define input parameters as z3 Ints - m, n, k = tune.z3.Int("m"), tune.z3.Int("n"), tune.z3.Int("k") - subgroup_size = tune.z3.Int("subgroup_size") - intrinsic_mn = tune.z3.Int("intrinsic_mn") - intrinsic_k = tune.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = tune.z3.Int("wg_x"), tune.z3.Int("wg_y"), tune.z3.Int("wg_z") - sg_m_cnt = tune.z3.Int("sg_m_cnt") - sg_n_cnt = tune.z3.Int("sg_n_cnt") - waves_per_eu = tune.z3.Int("waves_per_eu") - - constraints = tune.generate_constraints( + m, n, k = ( + candidate_gen.z3.Int("m"), + candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + subgroup_size = candidate_gen.z3.Int("subgroup_size") + intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") + intrinsic_k = candidate_gen.z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + candidate_gen.z3.Int("wg_x"), + candidate_gen.z3.Int("wg_y"), + candidate_gen.z3.Int("wg_z"), + ) + sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") + sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") + waves_per_eu = candidate_gen.z3.Int("waves_per_eu") + + constraints = candidate_gen.generate_constraints( problem_size, [m, n, k], 4, @@ -298,32 +346,40 @@ def test_generate_constraints_valid_input(): waves_per_eu, ) - solver = tune.z3.Solver() + solver = candidate_gen.z3.Solver() solver.add(constraints) # Check if the constraints are satisfiable - assert solver.check() == tune.z3.sat + assert solver.check() == candidate_gen.z3.sat def test_generate_constraints_invalid_input(): # Define input parameters that should lead to unsatisfiable constraints - matmul_size = tune.MatmulSize(1024, 1024, 1024) - lhs_type = tune.ShapedType([1024, 1024], tune.ElementType.f16) - rhs_type = tune.ShapedType([1024, 1024], tune.ElementType.f16) - res_type = tune.ShapedType([1024, 1024], tune.ElementType.f32) - problem_size = tune.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, tune.DispatchKind.mmt - ) - m, n, k = tune.z3.Int("m"), tune.z3.Int("n"), tune.z3.Int("k") - subgroup_size = tune.z3.Int("subgroup_size") - intrinsic_mn = tune.z3.Int("intrinsic_mn") - intrinsic_k = tune.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = tune.z3.Int("wg_x"), tune.z3.Int("wg_y"), tune.z3.Int("wg_z") - sg_m_cnt = tune.z3.Int("sg_m_cnt") - sg_n_cnt = tune.z3.Int("sg_n_cnt") - waves_per_eu = tune.z3.Int("waves_per_eu") - - constraints = tune.generate_constraints( + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = 
candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + m, n, k = ( + candidate_gen.z3.Int("m"), + candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + subgroup_size = candidate_gen.z3.Int("subgroup_size") + intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") + intrinsic_k = candidate_gen.z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + candidate_gen.z3.Int("wg_x"), + candidate_gen.z3.Int("wg_y"), + candidate_gen.z3.Int("wg_z"), + ) + sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") + sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") + waves_per_eu = candidate_gen.z3.Int("waves_per_eu") + + constraints = candidate_gen.generate_constraints( problem_size, [m, n, k], 4, @@ -336,11 +392,11 @@ def test_generate_constraints_invalid_input(): ) constraints.append(m > 1000) # Adding an additional unsatisfiable constraint - solver = tune.z3.Solver() + solver = candidate_gen.z3.Solver() solver.add(constraints) # Check if the constraints are unsatisfiable - assert solver.check() == tune.z3.unsat + assert solver.check() == candidate_gen.z3.unsat def test_apply_params_mmt(): @@ -353,24 +409,26 @@ def test_apply_params_mmt(): M, N, K = 2048, 1280, 1280 - config = tune.Configuration( + config = candidate_gen.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - intrinsic=tune.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), tile_sizes=[8, 8, 8], subgroup_m_count=16, subgroup_n_count=16, waves_per_eu=8, ) - problem_size = tune.ProblemSize( - tune.MatmulSize(M, N, K), - tune.ShapedType([M, K], tune.ElementType.f16), - tune.ShapedType([N, K], tune.ElementType.f16), - tune.ShapedType([M, N], tune.ElementType.f32), - tune.DispatchKind.mmt, + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(M, N, K), + candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + modified, embeddable = candidate_gen.apply_params_mmt( + problem_size, mlir_template, config ) - modified, embeddable = tune.apply_params_mmt(problem_size, mlir_template, config) assert modified assert embeddable @@ -396,24 +454,28 @@ def test_apply_params_conv(): n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 - config = tune.Configuration( + config = candidate_gen.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=tune.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, waves_per_eu=2, ) - problem_size = tune.ProblemSize( - tune.MatmulSize(oh * ow, oc, fh * fw * ic), - tune.ShapedType([n, oh + 2, ow + 2, oc], tune.ElementType.f16), - tune.ShapedType([fh, fw, ic, oc], tune.ElementType.f16), - tune.ShapedType([n, oh, ow, oc], tune.ElementType.f32), - tune.DispatchKind.conv, + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), + candidate_gen.ShapedType( + [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 + ), + candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), + 
+        candidate_gen.DispatchKind.conv,
+    )
+    modified, embeddable = candidate_gen.apply_params_conv(
+        problem_size, mlir_template, config
     )
-    modified, embeddable = tune.apply_params_conv(problem_size, mlir_template, config)
 
     assert modified
     assert embeddable
@@ -438,25 +500,25 @@ def test_apply_params_contract():
     ]
     tile_dims = "*mnk"
 
-    problem_size = tune.ProblemSize(
-        tune.MatmulSize(2048, 3840, 1280),
-        tune.ShapedType([2, 1024, 1280], tune.ElementType.f16),
-        tune.ShapedType([3, 20, 64, 1280], tune.ElementType.f16),
-        tune.ShapedType([3, 2, 20, 1024, 64], tune.ElementType.f32),
-        tune.DispatchKind.contraction,
+    problem_size = candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(2048, 3840, 1280),
+        candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32),
+        candidate_gen.DispatchKind.contraction,
     )
-    config = tune.Configuration(
+    config = candidate_gen.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        intrinsic=tune.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
         tile_sizes=[480, 384, 32],
         subgroup_m_count=1,
         subgroup_n_count=4,
         waves_per_eu=2,
     )
-    new_mlir, _embeddable = tune.apply_params_contract(
+    new_mlir, _embeddable = candidate_gen.apply_params_contract(
        problem_size, tile_dims, mlir_template, config
     )
 
@@ -482,25 +544,25 @@ def test_apply_params_batch_matmul():
     ]
     tile_dims = "bmnk"
 
-    problem_size = tune.ProblemSize(
-        tune.MatmulSize(968, 320, 640, 64),
-        tune.ShapedType([64, 968, 640], tune.ElementType.f16),
-        tune.ShapedType([64, 640, 320], tune.ElementType.f16),
-        tune.ShapedType([64, 968, 320], tune.ElementType.f32),
-        tune.DispatchKind.batch_matmul,
+    problem_size = candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(968, 320, 640, 64),
+        candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32),
+        candidate_gen.DispatchKind.batch_matmul,
     )
-    config = tune.Configuration(
+    config = candidate_gen.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=tune.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
         tile_sizes=[416, 320, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
         waves_per_eu=2,
     )
-    modified, embeddable = tune.apply_params_batch_matmul(
+    modified, embeddable = candidate_gen.apply_params_batch_matmul(
         problem_size, tile_dims, mlir_template, config
     )
 
@@ -526,25 +588,25 @@ def test_apply_params_batch_mmt_float():
         '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
     ]
 
-    problem_size = tune.ProblemSize(
-        tune.MatmulSize(4096, 640, 640, 2),
-        tune.ShapedType([2, 4096, 640], tune.ElementType.f16),
-        tune.ShapedType([2, 640, 640], tune.ElementType.f16),
-        tune.ShapedType([2, 4096, 640], tune.ElementType.f32),
-        tune.DispatchKind.batch_mmt,
+    problem_size = candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(4096, 640, 640, 2),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32),
+        candidate_gen.DispatchKind.batch_mmt,
     )
-    config = tune.Configuration(
+    config = candidate_gen.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=tune.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
         tile_sizes=[128, 64, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
         waves_per_eu=2,
     )
-    modified, embeddable = tune.apply_params_batch_mmt(
+    modified, embeddable = candidate_gen.apply_params_batch_mmt(
         problem_size, mlir_template, config
     )
 
@@ -570,25 +632,25 @@ def test_apply_params_batch_mmt_int():
         '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
     ]
 
-    problem_size = tune.ProblemSize(
-        tune.MatmulSize(4096, 640, 640, 2),
-        tune.ShapedType([2, 4096, 640], tune.ElementType.i8),
-        tune.ShapedType([2, 640, 640], tune.ElementType.i8),
-        tune.ShapedType([2, 4096, 640], tune.ElementType.i32),
-        tune.DispatchKind.batch_mmt,
+    problem_size = candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(4096, 640, 640, 2),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
+        candidate_gen.DispatchKind.batch_mmt,
     )
-    config = tune.Configuration(
+    config = candidate_gen.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=tune.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
         tile_sizes=[128, 64, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
         waves_per_eu=4,
     )
-    modified, embeddable = tune.apply_params_batch_mmt(
+    modified, embeddable = candidate_gen.apply_params_batch_mmt(
         problem_size, mlir_template, config
     )
 
@@ -635,25 +697,25 @@ def test_apply_params_broadcast_rhs_mmt():
         '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
     ]
 
-    problem_size = tune.ProblemSize(
-        tune.MatmulSize(4096, 640, 640, 2),
-        tune.ShapedType([2, 4096, 640], tune.ElementType.i8),
-        tune.ShapedType([640, 640], tune.ElementType.i8),
-        tune.ShapedType([2, 4096, 640], tune.ElementType.i32),
-        tune.DispatchKind.broadcast_rhs_mmt,
+    problem_size = candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(4096, 640, 640, 2),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
+        candidate_gen.DispatchKind.broadcast_rhs_mmt,
    )
-    config = tune.Configuration(
+    config = candidate_gen.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=tune.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
         tile_sizes=[128, 64, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
         waves_per_eu=4,
     )
-    modified, embeddable = tune.apply_params_broadcast_rhs_mmt(
+    modified, embeddable = candidate_gen.apply_params_broadcast_rhs_mmt(
         problem_size, mlir_template, config
     )
 
@@ -702,7 +764,7 @@ def test_detect_broadcast_rhs_mmt():
         r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
         r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
     ]
 
-    assert tune.is_broadcast_rhs_mmt(mlir_lines)
+    assert candidate_gen.is_broadcast_rhs_mmt(mlir_lines)
 
 
 def test_parse_mlir():
@@ -714,7 +776,9 @@ def test_parse_mlir():
       }
     }
 """
-    mlir_module = tune.parse_mlir(mlir_str)
+    mlir_module = candidate_gen.parse_mlir(mlir_str)
     assert mlir_module != None
-    assert isinstance(mlir_module, tune.ireec._mlir_libs._mlir.ir.Module)
-    assert isinstance(mlir_module.body.operations[0], tune.ireec.dialects.func.FuncOp)
+    assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module)
+    assert isinstance(
+        mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp
+    )
diff --git a/tuning/test_autotune.py b/tuning/test_libtuner.py
similarity index 72%
rename from tuning/test_autotune.py
rename to tuning/test_libtuner.py
index 4548456..394208c 100644
--- a/tuning/test_autotune.py
+++ b/tuning/test_libtuner.py
@@ -7,22 +7,22 @@
 import argparse
 import pytest
 from unittest.mock import call, patch, MagicMock
-import autotune
+import libtuner
 
 """
-Usage: python -m pytest test_autotune.py
+Usage: python -m pytest test_libtuner.py
 """
 
 
 def test_group_benchmark_results_by_device_id():
-    def generate_res(res_arg: str, device_id: int) -> autotune.TaskResult:
-        result: autotune.subprocess.CompletedProcess = (
-            autotune.subprocess.CompletedProcess(
+    def generate_res(res_arg: str, device_id: int) -> libtuner.TaskResult:
+        result: libtuner.subprocess.CompletedProcess = (
+            libtuner.subprocess.CompletedProcess(
                 args=[res_arg],
                 returncode=0,
             )
         )
-        return autotune.TaskResult(result=result, device_id=device_id)
+        return libtuner.TaskResult(result=result, device_id=device_id)
 
     test_input = [
         generate_res("str1", 3),
@@ -39,7 +39,7 @@ def generate_res(res_arg: str, device_id: int) -> autotune.TaskResult:
         [generate_res("str5", 7)],
     ]
 
-    actual_output = autotune.group_benchmark_results_by_device_id(test_input)
+    actual_output = libtuner.group_benchmark_results_by_device_id(test_input)
 
     for a, e in zip(actual_output, expect_output):
         for res1, res2 in zip(a, e):
@@ -48,7 +48,7 @@ def generate_res(res_arg: str, device_id: int) -> autotune.TaskResult:
 
 
 def test_sort_candidates_by_first_benchmark_times():
-    candidate_trackers = [autotune.CandidateTracker(i) for i in range(5)]
+    candidate_trackers = [libtuner.CandidateTracker(i) for i in range(5)]
     candidate_trackers[0].first_benchmark_time = 35
     candidate_trackers[1].first_benchmark_time = 2141
     candidate_trackers[2].first_benchmark_time = 231
@@ -57,7 +57,7 @@ def test_sort_candidates_by_first_benchmark_times():
     test_input = [i for i in range(5)]
     expect_output = [0, 4, 2, 3, 1]
     assert (
-        autotune.sort_candidates_by_first_benchmark_times(
+        libtuner.sort_candidates_by_first_benchmark_times(
             test_input, candidate_trackers
         )
         == expect_output
@@ -66,9 +66,9 @@ def test_find_collisions():
     input = [(1, "abc"), (2, "def"), (3, "abc")]
-    assert autotune.find_collisions(input) == (True, [("abc", [1, 3]), ("def", [2])])
+    assert libtuner.find_collisions(input) == (True, [("abc", [1, 3]), ("def", [2])])
 
     input = [(1, "abc"), (2, "def"), (3, "hig")]
-    assert autotune.find_collisions(input) == (
+    assert libtuner.find_collisions(input) == (
         False,
         [("abc", [1]), ("def", [2]), ("hig", [3])],
     )
@@ -76,40 +76,40 @@ def test_collision_handler():
     input = [(1, "abc"), (2, "def"), (3, "abc"), (4, "def"), (5, "hig")]
-    assert autotune.collision_handler(input) == (True, [1, 2, 5])
+    assert libtuner.collision_handler(input) == (True, [1, 2, 5])
 
     input = [(1, "abc"), (2, "def"), (3, "hig")]
-    assert autotune.collision_handler(input) == (False, [])
+    assert libtuner.collision_handler(input) == (False, [])
 
 
 def test_DispatchBenchmarkResult_get():
     normal_str = "2 Mean Time: 586.0"
-    res = autotune.DispatchBenchmarkResult(normal_str)
+    res = libtuner.DispatchBenchmarkResult(normal_str)
     assert res.result_str == normal_str
     assert res.get_tokens() == ["2", "Mean", "Time:", "586.0"]
     assert res.get_candidate_id() == 2
     assert res.get_benchmark_time() == 586.0
 
     incomplete_str = "2 Mean Time:"
-    res = autotune.DispatchBenchmarkResult(incomplete_str)
+    res = libtuner.DispatchBenchmarkResult(incomplete_str)
     assert res.get_tokens() == ["2", "Mean", "Time:"]
     assert res.get_candidate_id() == 2
     assert res.get_benchmark_time() == None
 
     incomplete_str = ""
-    res = autotune.DispatchBenchmarkResult(incomplete_str)
+    res = libtuner.DispatchBenchmarkResult(incomplete_str)
     assert res.get_tokens() == []
     assert res.get_candidate_id() == None
     assert res.get_benchmark_time() == None
 
     bad_str = 12345
-    res = autotune.DispatchBenchmarkResult(bad_str)
+    res = libtuner.DispatchBenchmarkResult(bad_str)
     assert res.get_tokens() == []
     assert res.get_candidate_id() == None
     assert res.get_benchmark_time() == None
 
 
-def test_UnetBenchmarkResult_get():
+def test_ModelBenchmarkResult_get():
     normal_str = "Benchmarking: unet_candidate_12.vmfb on device 24\nBM_main/process_time/real_time_median 182 ms 183 ms 5 items_per_second=5.50302/s"
-    res = autotune.UnetBenchmarkResult(normal_str)
+    res = libtuner.ModelBenchmarkResult(normal_str)
     assert res.result_str == normal_str
     assert res.get_tokens() == [
         "Benchmarking:",
@@ -125,48 +125,48 @@ def test_UnetBenchmarkResult_get():
         "5",
         "items_per_second=5.50302/s",
     ]
-    assert res.get_unet_candidate_path() == "unet_candidate_12.vmfb"
+    assert res.get_model_candidate_path() == "unet_candidate_12.vmfb"
     assert res.get_candidate_id() == 12
     assert res.get_device_id() == 24
     assert res.get_benchmark_time() == 182.0
 
-    incomplete_str = "Benchmarking: unet_baseline.vmfb on device 24\n"
-    res = autotune.UnetBenchmarkResult(incomplete_str)
+    incomplete_str = "Benchmarking: baseline.vmfb on device 24\n"
+    res = libtuner.ModelBenchmarkResult(incomplete_str)
     assert res.get_tokens() == [
         "Benchmarking:",
-        "unet_baseline.vmfb",
+        "baseline.vmfb",
         "on",
         "device",
         "24",
     ]
-    assert res.get_unet_candidate_path() == "unet_baseline.vmfb"
+    assert res.get_model_candidate_path() == "baseline.vmfb"
     assert res.get_candidate_id() == None
     assert res.get_device_id() == 24
     assert res.get_benchmark_time() == None
 
     incomplete_str = ""
-    res = autotune.UnetBenchmarkResult(incomplete_str)
+    res = libtuner.ModelBenchmarkResult(incomplete_str)
     assert res.get_tokens() == []
-    assert res.get_unet_candidate_path() == None
+    assert res.get_model_candidate_path() == None
     assert res.get_candidate_id() == None
     assert res.get_device_id() == None
     assert res.get_benchmark_time() == None
 
     bad_str = 12345
-    res = autotune.UnetBenchmarkResult(bad_str)
+    res = libtuner.ModelBenchmarkResult(bad_str)
     assert res.get_tokens() == []
-    assert res.get_unet_candidate_path() == None
+    assert res.get_model_candidate_path() == None
     assert res.get_candidate_id() == None
     assert res.get_device_id() == None
     assert res.get_benchmark_time() == None
 
 
 def test_generate_sample_result():
-    res = autotune.DispatchBenchmarkResult()
+    res = libtuner.DispatchBenchmarkResult()
     output = res.generate_sample_result(1, 3.14)
     expected = f"1\tMean Time: 3.1\n"
     assert output == expected, "DispatchBenchmarkResult generates invalid sample string"
 
-    res = autotune.UnetBenchmarkResult()
+    res = libtuner.ModelBenchmarkResult()
     output = res.generate_sample_result(
         1, "some_dir/tuning_2024_07_24_20_06/unet_candidate_60.vmfb.vmfb", 576.89
     )
@@ -174,12 +174,12 @@
     assert output == expected, "UnetBenchmarkResult generates invalid sample string"
 
 
-def test_UnetBenchmarkResult_get_calibrated_result_str():
+def test_ModelBenchmarkResult_get_calibrated_result_str():
     baseline_time = 423
     res_time = 304
     result_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median {float(res_time)} ms 305 ms 5 items_per_second=1.520000/s"
     change = (res_time - baseline_time) / baseline_time
-    output_str = autotune.UnetBenchmarkResult(result_str).get_calibrated_result_str(
+    output_str = libtuner.ModelBenchmarkResult(result_str).get_calibrated_result_str(
         change
     )
     expect_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median\t {float(res_time)} ms (-28.132%)\t 305 ms\t 5 items_per_second=1.520000/s"
@@ -189,7 +189,7 @@ def test_UnetBenchmarkResult_get_calibrated_result_str():
     res_time = 218
     result_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median {float(res_time)} ms 305 ms 5 items_per_second=1.520000/s"
     change = (res_time - baseline_time) / baseline_time
-    output_str = autotune.UnetBenchmarkResult(result_str).get_calibrated_result_str(
+    output_str = libtuner.ModelBenchmarkResult(result_str).get_calibrated_result_str(
         change
     )
     expect_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median\t {float(res_time)} ms (+0.000%)\t 305 ms\t 5 items_per_second=1.520000/s"
@@ -199,7 +199,7 @@ def test_UnetBenchmarkResult_get_calibrated_result_str():
     res_time = 345
     result_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median {float(res_time)} ms 305 ms 5 items_per_second=1.520000/s"
     change = (res_time - baseline_time) / baseline_time
-    output_str = autotune.UnetBenchmarkResult(result_str).get_calibrated_result_str(
+    output_str = libtuner.ModelBenchmarkResult(result_str).get_calibrated_result_str(
         change
     )
     expect_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median\t {float(res_time)} ms (+180.488%)\t 305 ms\t 5 items_per_second=1.520000/s"
@@ -207,18 +207,18 @@
 
 
 def test_parse_dispatch_benchmark_results():
-    def generate_res(stdout: str) -> autotune.TaskResult:
-        result = autotune.subprocess.CompletedProcess(
+    def generate_res(stdout: str) -> libtuner.TaskResult:
+        result = libtuner.subprocess.CompletedProcess(
             args=[""],
             stdout=stdout,
             returncode=0,
         )
-        return autotune.TaskResult(result)
+        return libtuner.TaskResult(result)
 
     def generate_parsed_disptach_benchmark_result(
         time: float, i: int
-    ) -> autotune.ParsedDisptachBenchmarkResult:
-        return autotune.ParsedDisptachBenchmarkResult(
+    ) -> libtuner.ParsedDisptachBenchmarkResult:
+        return libtuner.ParsedDisptachBenchmarkResult(
             i,
             time,
             path_config.get_candidate_mlir_path(i),
@@ -234,22 +234,26 @@ def generate_parsed_disptach_benchmark_result(
         for i in random_order
     ]
 
-    path_config = autotune.PathConfig()
+    path_config = libtuner.PathConfig()
 
     candidate_trackers = [
-        autotune.CandidateTracker(i, mlir_path=path_config.get_candidate_mlir_path(i))
+        libtuner.CandidateTracker(
+            i, dispatch_mlir_path=path_config.get_candidate_mlir_path(i)
+        )
         for i in range(total)
     ]
     candidate_trackers_before = [
-        autotune.CandidateTracker(i, mlir_path=path_config.get_candidate_mlir_path(i))
+        libtuner.CandidateTracker(
+            i, dispatch_mlir_path=path_config.get_candidate_mlir_path(i)
+        )
         for i in range(total)
     ]
     expect_candidate_trackers = [
-        autotune.CandidateTracker(
+        libtuner.CandidateTracker(
             i,
-            mlir_path=path_config.get_candidate_mlir_path(i),
-            mlir_spec_path=path_config.get_candidate_spec_mlir_path(i),
+            dispatch_mlir_path=path_config.get_candidate_mlir_path(i),
+            spec_path=path_config.get_candidate_spec_mlir_path(i),
         )
         for i in range(total)
     ]
@@ -264,8 +268,12 @@ def generate_parsed_disptach_benchmark_result(
         f"{test_list[i][0]} Mean Time: {test_list[i][1]}" for i in random_order
     ]
 
-    parsed_results, dump_list = autotune.parse_dispatch_benchmark_results(
-        path_config, benchmark_results, candidate_trackers
+    mock_tuning_client = MagicMock()
+    mock_tuning_client.get_candidate_spec_filename.side_effect = (
+        lambda i: f"{i}_spec.mlir"
+    )
+    parsed_results, dump_list = libtuner.parse_dispatch_benchmark_results(
+        path_config, benchmark_results, candidate_trackers, mock_tuning_client
     )
 
     assert parsed_results == expect_parsed_results
@@ -275,28 +283,28 @@ def generate_parsed_disptach_benchmark_result(
 
 
 def test_parse_grouped_benchmark_results():
-    def generate_res(stdout: str, device_id: int) -> autotune.TaskResult:
-        result = autotune.subprocess.CompletedProcess(
+    def generate_res(stdout: str, device_id: int) -> libtuner.TaskResult:
+        result = libtuner.subprocess.CompletedProcess(
             args=[""],
             stdout=stdout,
             returncode=0,
         )
-        return autotune.TaskResult(result=result, device_id=device_id)
+        return libtuner.TaskResult(result=result, device_id=device_id)
 
     def set_tracker(
-        tracker: autotune.CandidateTracker,
-        unet_benchmark_time: float,
-        unet_benchmark_device_id: int,
+        tracker: libtuner.CandidateTracker,
+        model_benchmark_time: float,
+        model_benchmark_device_id: int,
         baseline_benchmark_time: float,
         calibrated_benchmark_diff=float,
     ):
-        tracker.unet_benchmark_time = unet_benchmark_time
-        tracker.unet_benchmark_device_id = unet_benchmark_device_id
+        tracker.model_benchmark_time = model_benchmark_time
+        tracker.model_benchmark_device_id = model_benchmark_device_id
         tracker.baseline_benchmark_time = baseline_benchmark_time
         tracker.calibrated_benchmark_diff = calibrated_benchmark_diff
 
-    b1 = "Benchmarking: some_dir/unet_baseline.vmfb on device 0 BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s"
-    b2 = "Benchmarking: unet_baseline.vmfb on device 1 BM_main/process_time/real_time_median 59.8 ms 15.1 ms 5 items_per_second=16.7114/s"
+    b1 = "Benchmarking: some_dir/baseline.vmfb on device 0 BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s"
+    b2 = "Benchmarking: baseline.vmfb on device 1 BM_main/process_time/real_time_median 59.8 ms 15.1 ms 5 items_per_second=16.7114/s"
     s1 = "Benchmarking: unet_candidate_1.vmfb on device 0 BM_main/process_time/real_time_median 62.4 ms 15.4 ms 5 items_per_second=16.0223/s"
     s2 = "Benchmarking: some_dir/unet_candidate_2.vmfb on device 1 BM_main/process_time/real_time_median 61.4 ms 11.0 ms 5 items_per_second=16.2958/s"
     s3 = "Benchmarking: unet_candidate_4.vmfb on device 1 BM_main/process_time/real_time_median 57.4 ms 11.0 ms 5 items_per_second=16.2958/s"
@@ -311,22 +319,22 @@ def set_tracker(
         ],
     ]
 
-    path_config = autotune.PathConfig()
+    path_config = libtuner.PathConfig()
 
-    candidate_trackers = [autotune.CandidateTracker(i) for i in range(5)]
+    candidate_trackers = [libtuner.CandidateTracker(i) for i in range(5)]
 
-    candidate_trackers_before = [autotune.CandidateTracker(i) for i in range(5)]
-    expect_candidate_trackers = [autotune.CandidateTracker(i) for i in range(5)]
+    candidate_trackers_before = [libtuner.CandidateTracker(i) for i in range(5)]
+    expect_candidate_trackers = [libtuner.CandidateTracker(i) for i in range(5)]
     set_tracker(expect_candidate_trackers[1], 62.4, 0, 60.7, 0.028006589785831888)
     set_tracker(expect_candidate_trackers[2], 61.4, 1, 59.8, 0.02675585284280939)
     set_tracker(expect_candidate_trackers[4], 57.4, 1, 59.8, -0.04013377926421403)
 
     expect_dump_list = [
-        "Benchmarking: some_dir/unet_baseline.vmfb on device 0 "
+        "Benchmarking: some_dir/baseline.vmfb on device 0 "
         "BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s",
         "Benchmarking: unet_candidate_1.vmfb on device 0 "
         "BM_main/process_time/real_time_median 62.4 ms (+2.801%) 15.4 ms 5 items_per_second=16.0223/s",
-        "Benchmarking: unet_baseline.vmfb on device 1 "
+        "Benchmarking: baseline.vmfb on device 1 "
         "BM_main/process_time/real_time_median 59.8 ms 15.1 ms 5 items_per_second=16.7114/s",
         "Benchmarking: unet_candidate_4.vmfb on device 1 "
         "BM_main/process_time/real_time_median 57.4 ms (-4.013%) 11.0 ms 5 items_per_second=16.2958/s",
@@ -334,7 +342,7 @@ def set_tracker(
         "BM_main/process_time/real_time_median 61.4 ms (+2.676%) 11.0 ms 5 items_per_second=16.2958/s",
     ]
 
-    dump_list = autotune.parse_grouped_benchmark_results(
+    dump_list = libtuner.parse_grouped_benchmark_results(
         path_config, grouped_benchmark_results, candidate_trackers
     )
 
@@ -346,42 +354,42 @@ def set_tracker(
         candidate_trackers == expect_candidate_trackers
     ), "candidate_trackers did not change as expected"
 
-    b1 = "Benchmarking: unet_baseline.vmfb on device 0"
+    b1 = "Benchmarking: baseline.vmfb on device 0"
     s1 = "Benchmarking: unet_candidate_1.vmfb on device 0 BM_main/process_time/real_time_median 62.4 ms 15.4 ms 5 items_per_second=16.0223/s"
     grouped_benchmark_results = [[generate_res(b1, 0), generate_res(s1, 0)]]
-    dump_list = autotune.parse_grouped_benchmark_results(
+    dump_list = libtuner.parse_grouped_benchmark_results(
         path_config, grouped_benchmark_results, candidate_trackers
     )
     expect_dump_list = [
         "Benchmarking: unet_candidate_1.vmfb on device 0 "
         "BM_main/process_time/real_time_median 62.4 ms 15.4 ms 5 items_per_second=16.0223/s",
-        "Benchmarking result of unet_baseline.vmfb on deivce 0 is incomplete\n",
+        "Benchmarking result of baseline.vmfb on deivce 0 is incomplete\n",
     ]
     assert dump_list == expect_dump_list, "fail to parse incomplete baselines"
 
-    b1 = "Benchmarking: some_dir/unet_baseline.vmfb on device 0 BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s"
+    b1 = "Benchmarking: some_dir/baseline.vmfb on device 0 BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s"
     s1 = "Benchmarking: unet_candidate_1.vmfb on device 0"
     grouped_benchmark_results = [[generate_res(b1, 0), generate_res(s1, 0)]]
-    candidate_trackers[1].unet_candidate_path = "unet_candidate_1.vmfb"
-    dump_list = autotune.parse_grouped_benchmark_results(
+    candidate_trackers[1].model_path = "unet_candidate_1.vmfb"
+    dump_list = libtuner.parse_grouped_benchmark_results(
         path_config, grouped_benchmark_results, candidate_trackers
     )
     expect_dump_list = [
-        "Benchmarking: some_dir/unet_baseline.vmfb on device 0 "
+        "Benchmarking: some_dir/baseline.vmfb on device 0 "
         "BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s",
         "Benchmarking result of unet_candidate_1.vmfb on deivce 0 is incomplete\n",
     ]
     assert dump_list == expect_dump_list, "fail to parse incomplete candidates"
 
-    b1 = "Benchmarking: unet_baseline.vmfb on device 0"
+    b1 = "Benchmarking: baseline.vmfb on device 0"
     s1 = "Benchmarking: unet_candidate_1.vmfb on device 0"
     grouped_benchmark_results = [[generate_res(b1, 0), generate_res(s1, 0)]]
-    candidate_trackers[1].unet_candidate_path = "unet_candidate_1.vmfb"
-    dump_list = autotune.parse_grouped_benchmark_results(
+    candidate_trackers[1].model_path = "unet_candidate_1.vmfb"
+    dump_list = libtuner.parse_grouped_benchmark_results(
         path_config, grouped_benchmark_results, candidate_trackers
     )
     expect_dump_list = [
-        "Benchmarking result of unet_baseline.vmfb on deivce 0 is incomplete\n",
+        "Benchmarking result of baseline.vmfb on deivce 0 is incomplete\n",
         "Benchmarking result of unet_candidate_1.vmfb on deivce 0 is incomplete\n",
     ]
     assert (
@@ -393,7 +401,7 @@ def test_extract_driver_names():
     user_devices = ["hip://0", "local-sync://default", "cuda://default"]
     expected_output = {"hip", "local-sync", "cuda"}
 
-    assert autotune.extract_driver_names(user_devices) == expected_output
+    assert libtuner.extract_driver_names(user_devices) == expected_output
 
 
 def test_fetch_available_devices_success():
@@ -404,7 +412,7 @@ def test_fetch_available_devices_success():
         "cuda": [{"path": "default"}],
     }
 
-    with patch("autotune.ireert.get_driver") as mock_get_driver:
+    with patch("libtuner.ireert.get_driver") as mock_get_driver:
         mock_driver = MagicMock()
 
         def get_mock_driver(name):
@@ -413,7 +421,7 @@ def get_mock_driver(name):
 
         mock_get_driver.side_effect = get_mock_driver
 
-        actual_output = autotune.fetch_available_devices(drivers)
+        actual_output = libtuner.fetch_available_devices(drivers)
         expected_output = ["hip://0", "local-sync://default", "cuda://default"]
         assert actual_output == expected_output
 
@@ -427,8 +435,8 @@ def test_fetch_available_devices_failure():
         "cuda": [{"path": "default"}],
     }
 
-    with patch("autotune.ireert.get_driver") as mock_get_driver:
-        with patch("autotune.handle_error") as mock_handle_error:
+    with patch("libtuner.ireert.get_driver") as mock_get_driver:
+        with patch("libtuner.handle_error") as mock_handle_error:
            mock_driver = MagicMock()
 
            def get_mock_driver(name):
@@ -444,7 +452,7 @@ def get_mock_driver(name):
 
            mock_get_driver.side_effect = get_mock_driver
 
-            actual_output = autotune.fetch_available_devices(drivers)
+            actual_output = libtuner.fetch_available_devices(drivers)
            expected_output = ["hip://0", "cuda://default"]
            assert actual_output == expected_output
 
@@ -460,8 +468,8 @@ def test_parse_devices():
     user_devices_str = "hip://0, local-sync://default, cuda://default"
     expected_output = ["hip://0", "local-sync://default", "cuda://default"]
 
-    with patch("autotune.handle_error") as mock_handle_error:
-        actual_output = autotune.parse_devices(user_devices_str)
+    with patch("libtuner.handle_error") as mock_handle_error:
+        actual_output = libtuner.parse_devices(user_devices_str)
         assert actual_output == expected_output
 
         mock_handle_error.assert_not_called()
@@ -476,8 +484,8 @@ def test_parse_devices_with_invalid_input():
         "cuda://default",
     ]
 
-    with patch("autotune.handle_error") as mock_handle_error:
-        actual_output = autotune.parse_devices(user_devices_str)
+    with patch("libtuner.handle_error") as mock_handle_error:
+        actual_output = libtuner.parse_devices(user_devices_str)
         assert actual_output == expected_output
         mock_handle_error.assert_called_once_with(
@@ -491,13 +499,13 @@ def test_validate_devices():
     user_devices = ["hip://0", "local-sync://default"]
     user_drivers = {"hip", "local-sync"}
 
-    with patch("autotune.extract_driver_names", return_value=user_drivers):
+    with patch("libtuner.extract_driver_names", return_value=user_drivers):
         with patch(
-            "autotune.fetch_available_devices",
+            "libtuner.fetch_available_devices",
             return_value=["hip://0", "local-sync://default"],
         ):
-            with patch("autotune.handle_error") as mock_handle_error:
-                autotune.validate_devices(user_devices)
+            with patch("libtuner.handle_error") as mock_handle_error:
+                libtuner.validate_devices(user_devices)
                 assert all(
                     call[1]["condition"] is False
                     for call in mock_handle_error.call_args_list
@@ -508,13 +516,13 @@ def test_validate_devices_with_invalid_device():
     user_devices = ["hip://0", "local-sync://default", "cuda://default"]
     user_drivers = {"hip", "local-sync", "cuda"}
 
-    with patch("autotune.extract_driver_names", return_value=user_drivers):
+    with patch("libtuner.extract_driver_names", return_value=user_drivers):
         with patch(
-            "autotune.fetch_available_devices",
+            "libtuner.fetch_available_devices",
             return_value=["hip://0", "local-sync://default"],
         ):
-            with patch("autotune.handle_error") as mock_handle_error:
-                autotune.validate_devices(user_devices)
+            with patch("libtuner.handle_error") as mock_handle_error:
+                libtuner.validate_devices(user_devices)
                 expected_call = call(
                     condition=True,
                     msg=f"Invalid device specified: cuda://default\nFetched available devices: ['hip://0', 'local-sync://default']",