diff --git a/scripts/RunRiscVArchTest.py b/scripts/RunRiscVArchTest.py index c9e3e65..2fac93c 100755 --- a/scripts/RunRiscVArchTest.py +++ b/scripts/RunRiscVArchTest.py @@ -37,11 +37,12 @@ def main(): print("WARNING: Skipping all non-G extension tests") tests = [test for test in tests if any(ext+"-" in test for ext in SUPPORTED_EXTENSIONS)] - # Run RV64 tests - passing_tests = [] - failing_tests = [] - print("Running " + str(len(tests)) + " arch tests...") - for test in tests: + import multiprocessing + passing_tests = multiprocessing.Queue() + failing_tests = multiprocessing.Queue() + + # Function to run a single test and append to the appropriate queue + def run_test(test, passing_tests, failing_tests): testname = os.path.basename(test) logname = testname + ".log" instlogname = testname + ".instlog" @@ -53,25 +54,44 @@ def main(): if result.returncode == 0: test_passed = True except subprocess.TimeoutExpired: - continue + return if test_passed: - passing_tests.append(testname) + passing_tests.put(testname) # Remove log files if test passed os.remove(logname) os.remove(instlogname) else: - failing_tests.append(testname) + failing_tests.put(testname) + + # Function to run tests using processes + def run_tests_in_parallel(tests, passing_tests, failing_tests): + print("Running " + str(len(tests)) + " arch tests...") + processes = [] + + # Create a process for each test command + for test in tests: + process = multiprocessing.Process(target=run_test, args=(test, passing_tests, failing_tests)) + process.start() + processes.append(process) + # Wait for all processes to finish + for process in processes: + process.join() + run_tests_in_parallel(tests, passing_tests, failing_tests) + + num_passed = 0 print("PASSED:") - for test in passing_tests: - print("\t" + test) + while not passing_tests.empty(): + print("\t" + passing_tests.get()) + num_passed += 1 + print("FAILED:") - for test in failing_tests: - print("\t" + test) + while not failing_tests.empty(): + print("\t" + failing_tests.get()) - print("\nPASS RATE: " + str(len(passing_tests)) + "/" + str(len(tests))) + print("\nPASS RATE: " + str(num_passed) + "/" + str(len(tests))) if __name__ == "__main__":