Skip to content

Commit

Permalink
#8380: re-arranged ttnn sweeps structure
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed May 11, 2024
1 parent 334fc76 commit 3da50e2
Show file tree
Hide file tree
Showing 148 changed files with 176 additions and 183 deletions.
21 changes: 6 additions & 15 deletions tests/ttnn/sweep_tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,29 @@ python tests/ttnn/sweep_tests/run_sweeps.py

## Running a single sweep
```
python tests/ttnn/sweep_tests/run_sweeps.py --include add,matmul
```

## Running a single test
```
python tests/ttnn/sweep_tests/run_single_test.py --test-name add --index 0
python tests/ttnn/sweep_tests/run_sweeps.py --include add.py,matmul.py
```

## Printing report of all sweeps
```
python tests/ttnn/sweep_tests/print_report.py [--detailed]
```

## Debugging sweeps
```
python tests/ttnn/sweep_tests/run_failed_and_crashed_tests.py [--exclude add,linear] [--stepwise]
```

## Using Pytest to run all the sweeps for one operation file
```
pytest <full-path-to-tt-metal>/tt-metal/tests/ttnn/sweep_tests/test_all_sweep_tests.py::test_<operation>
Example for matmul: pytest tests/ttnn/sweep_tests/test_all_sweep_tests.py::test_matmul
pytest <full-path-to-tt-metal>/tt-metal/tests/ttnn/sweep_tests/test_sweeps.py::test_<operation>
Example for matmul: pytest tests/ttnn/sweep_tests/test_sweeps.py::test_matmul
```

## Using Pytest to run a single sweep test by the index
```
pytest <full-path-to-tt-metal>/tt-metal/tests/ttnn/sweep_tests/test_all_sweep_tests.py::test_<operation>[<operation>.py-<index-of-test-instance>]
Example for matmul: pytest tests/ttnn/sweep_tests/test_all_sweep_tests.py::test_matmul[matmul.py-0]
pytest <full-path-to-tt-metal>/tt-metal/tests/ttnn/sweep_tests/test_sweeps.py::test_<operation>[<operation>.py-<index-of-test-instance>]
Example for matmul: TODO(arakhmati)
```

## Adding a new sweep test
In `tests/ttnn/sweep_tests/sweeps` add a new file `<new_file>.py`.
In `tests/ttnn/sweep_tests/sweeps` add a new file `<new_file>.py`. (You can add new folders as well.)

The file must contain:
- `parameters` dictionary from a variable to the list of values to sweep
Expand Down
2 changes: 1 addition & 1 deletion tests/ttnn/sweep_tests/print_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import argparse

from tests.ttnn.sweep_tests.sweep import print_report
from sweeps import print_report


def main():
Expand Down
43 changes: 0 additions & 43 deletions tests/ttnn/sweep_tests/run_failed_and_crashed_tests.py

This file was deleted.

45 changes: 0 additions & 45 deletions tests/ttnn/sweep_tests/run_single_test.py

This file was deleted.

6 changes: 3 additions & 3 deletions tests/ttnn/sweep_tests/run_sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import ttnn


from tests.ttnn.sweep_tests.sweep import run_sweeps, print_report
from sweeps import run_sweeps, print_report


def convert_string_to_list(string):
Expand All @@ -28,9 +28,9 @@ def main():
include = convert_string_to_list(include)

device = ttnn.open_device(device_id=0)
run_sweeps(device=device, include=include)
table_names = run_sweeps(device=device, include=include)
ttnn.close_device(device)
print_report(include=include)
print_report(table_names=table_names)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions tests/ttnn/sweep_tests/sweeps/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
**/db.sqlite
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,28 @@

# SPDX-License-Identifier: Apache-2.0

import datetime
from importlib.machinery import SourceFileLoader
import pathlib
import pickle
import zlib

from loguru import logger
import pandas as pd
import sqlite3
from tqdm import tqdm

SWEEPS_DIR = pathlib.Path(__file__).parent
SWEEP_SOURCES_DIR = SWEEPS_DIR / "sweeps"
SWEEP_RESULTS_DIR = SWEEPS_DIR / "results"
DATABASE_FILE_NAME = SWEEP_RESULTS_DIR / "db.sqlite"

if not SWEEP_RESULTS_DIR.exists():
SWEEP_RESULTS_DIR.mkdir(parents=True, exist_ok=True)


def get_sweep_name(file_name):
    """Return the sweep name for ``file_name``.

    The name is the path of ``file_name`` relative to ``SWEEP_SOURCES_DIR``,
    converted to a string. ``file_name`` is wrapped in ``pathlib.Path`` first,
    so a string or a path object is accepted. ``Path.relative_to`` raises
    ``ValueError`` when ``file_name`` is not located under ``SWEEP_SOURCES_DIR``.
    """
    sweep_path = pathlib.Path(file_name)
    relative_path = sweep_path.relative_to(SWEEP_SOURCES_DIR)
    return str(relative_path)


def permutations(parameters):
Expand Down Expand Up @@ -91,11 +103,11 @@ def _run_single_test(run, skip, is_expected_to_fail, permutation, *, device):
return status, message


def run_single_test(test_name, index, *, device):
file_name = (SWEEP_SOURCES_DIR / test_name).with_suffix(".py")
def run_single_test(file_name, index, *, device):
logger.info(f"Running {file_name}")

sweep_module = SourceFileLoader(f"sweep_module_{file_name.stem}", str(file_name)).load_module()
sweep_name = get_sweep_name(file_name)
sweep_module = SourceFileLoader(f"sweep_module_{sweep_name}", str(file_name)).load_module()
permutation = list(permutations(sweep_module.parameters))[index]

pretty_printed_parameters = ",\n".join(
Expand All @@ -107,82 +119,69 @@ def run_single_test(test_name, index, *, device):
)


def run_sweep(sweep_file_name, *, device):
sweep_name = pathlib.Path(sweep_file_name).stem
sweep_module = SourceFileLoader(f"sweep_module_{sweep_name}", str(sweep_file_name)).load_module()
def run_sweep(file_name, *, device):
current_datetime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
sweep_name = get_sweep_name(file_name)
sweep_module = SourceFileLoader(f"sweep_module_{sweep_name}", str(file_name)).load_module()

parameter_names = get_parameter_names(sweep_module.parameters)
column_names = ["status", "exception"] + parameter_names
column_names = ["sweep_name", "timestamp", "status", "exception"] + parameter_names

rows = []
for permutation in tqdm(list(permutations(sweep_module.parameters))):
status, message = _run_single_test(
sweep_module.run, sweep_module.skip, sweep_module.is_expected_to_fail, permutation, device=device
)
rows.append([status, message] + list(get_parameter_values(parameter_names, permutation)))

SWEEP_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
file_name = (SWEEP_RESULTS_DIR / sweep_name).with_suffix(".csv")
rows.append(
[sweep_name, current_datetime, status, message] + list(get_parameter_values(parameter_names, permutation))
)

df = pd.DataFrame(rows, columns=column_names)
df.to_csv(file_name)
connection = sqlite3.connect(DATABASE_FILE_NAME)
cursor = connection.cursor()

logger.info(f"Saved sweep results to {file_name}")
table_hash = zlib.adler32(pickle.dumps(f"{sweep_name}_{current_datetime}"))
table_name = f"table_{table_hash}"

def column_names_to_string(column_names):
def name_to_string(name):
if name == "timestamp":
return "timestamp TIMESTAMP"
else:
return f"{name} TEXT"

def run_sweeps(*, device, include):
logger.info(f"Deleting old sweep results in {SWEEP_RESULTS_DIR}")
if SWEEP_RESULTS_DIR.exists():
for file_name in SWEEP_RESULTS_DIR.glob("*.csv"):
file_name.unlink()

for file_name in sorted(SWEEP_SOURCES_DIR.glob("*.py")):
name = file_name.stem
if include and name not in include:
continue
logger.info(f"Running {file_name}")
run_sweep(file_name, device=device)
column_names = [name_to_string(name) for name in column_names]
return ", ".join(column_names)

command = f"CREATE TABLE IF NOT EXISTS {table_name} ({column_names_to_string(column_names)})"
cursor.execute(command)

def run_failed_and_crashed_tests(*, device, stepwise, include, exclude):
keep_running = True
for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")):
test_name = file_name.stem
for row in rows:
row = [str(value) for value in row]
row_placeholders = ", ".join(["?"] * len(column_names))
command = f"INSERT INTO {table_name} VALUES ({row_placeholders})"
cursor.execute(command, row)
connection.commit()
connection.close()

if include and test_name not in include:
continue
SWEEP_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
logger.info(f"Saved sweep results to table {table_name} in {DATABASE_FILE_NAME}")

if exclude and test_name in exclude:
continue
return table_name

if not keep_running:
break

df = pd.read_csv(file_name)
failed = (df["status"] == "failed").sum()
crashed = (df["status"] == "crashed").sum()
if failed == 0 and crashed == 0:
def run_sweeps(*, device, include):
table_names = []
for file_name in sorted(SWEEP_SOURCES_DIR.glob("**/*.py")):
sweep_name = get_sweep_name(file_name)
if include and sweep_name not in include:
continue

for index, row in enumerate(df.itertuples()):
if row.status not in {"failed", "crashed"}:
continue

status, message = run_single_test(file_name.stem, index, device=device)
logger.info(status)
if status in {"failed", "crashed"}:
logger.error(f"{message}")
if stepwise:
keep_running = False
break

df.at[index, "status"] = status
df.at[index, "message"] = message

df.to_csv(file_name)
logger.info(f"Running {file_name}")
table_name = run_sweep(file_name, device=device)
table_names.append(table_name)
return table_names


def print_summary(*, include):
def print_summary(*, table_names):
stats_df = pd.DataFrame(columns=["name", "passed", "failed", "crashed", "skipped", "is_expected_to_fail"])

def add_row(df, name):
Expand All @@ -191,12 +190,15 @@ def add_row(df, name):
df.reset_index(inplace=True, drop=True)
return df

for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")):
name = file_name.stem
if include and name not in include:
connection = sqlite3.connect(DATABASE_FILE_NAME)
cursor = connection.cursor()

for (table_name,) in cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall():
if table_names is not None and table_name not in table_names:
continue
df = pd.read_csv(file_name)
stats_df = add_row(stats_df, file_name.stem)
df = pd.read_sql_query(f"SELECT * FROM {table_name}", connection)
sweep_name = df["sweep_name"].iloc[0]
stats_df = add_row(stats_df, sweep_name)
for status in stats_df.columns[1:]:
stats_df.at[len(stats_df) - 1, status] = (df["status"] == status).sum()

Expand All @@ -206,25 +208,27 @@ def add_row(df, name):
print(stats_df)


def print_detailed_report(*, include):
for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")):
name = file_name.stem
if include and name not in include:
def print_detailed_report(*, table_names):
connection = sqlite3.connect(DATABASE_FILE_NAME)
cursor = connection.cursor()

for (table_name,) in cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall():
if table_names is not None and table_name not in table_names:
continue
df = pd.read_csv(file_name)
df = pd.read_sql_query(f"SELECT * FROM {table_name}", connection)
for index, row in enumerate(df.itertuples()):
if row.status in {"failed", "crashed"}:
print(f"{name}@{index}: {row.status}")
print(f"{table_name}@{index}: {row.status}")
print(f"\t{row.exception}")
elif row.status == "skipped":
print(f"{name}@{index}: {row.status}")
print(f"{table_name}@{index}: {row.status}")
else:
print(f"{name}@{index}: {row.status}")
print(f"{table_name}@{index}: {row.status}")
print()


def print_report(*, include=None, detailed=False):
def print_report(*, table_names=None, detailed=False):
if detailed:
print_detailed_report(include=include)
print_detailed_report(table_names=table_names)
else:
print_summary(include=include)
print_summary(table_names=table_names)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 3da50e2

Please sign in to comment.