Skip to content

Commit

Permalink
Update Tests Fetcher (#5950)
Browse files Browse the repository at this point in the history
* update setup and deps table

* update

* update

* update

* up

* up

* update

* up

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* quality fix

* fix failure reporting

---------

Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
DN6 and patrickvonplaten authored Dec 4, 2023
1 parent 8a812e4 commit b217292
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 40 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/pr_test_fetcher.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ jobs:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
fetch-depth: 0
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .
python -m pip install -e .[quality,test]
- name: Environment
run: |
python utils/print_env.py
echo $(git --version)
- name: Fetch Tests
run: |
python utils/tests_fetcher.py | tee test_preparation.txt
Expand Down Expand Up @@ -110,7 +111,7 @@ jobs:
continue-on-error: true
run: |
cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
cat reports/${{ matrix.modules }}_tests_cpu/failures_short.txt
cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
"ruff>=0.1.5,<=0.2",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
"GitPython<3.1.19",
"scipy",
"onnx",
"regex!=2019.12.17",
Expand Down Expand Up @@ -206,6 +207,7 @@ def run(self):
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
extras["test"] = deps_list(
"compel",
"GitPython",
"datasets",
"Jinja2",
"invisible-watermark",
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"ruff": "ruff>=0.1.5,<=0.2",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"regex": "regex!=2019.12.17",
Expand Down
95 changes: 58 additions & 37 deletions utils/tests_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,23 @@
PATH_TO_DIFFUSERS = PATH_TO_REPO / "src/diffusers"
PATH_TO_TESTS = PATH_TO_REPO / "tests"

# List here the pipelines to always test.
# Ignore fixtures in tests folder
# Ignore lora since they are always tested
MODULES_TO_IGNORE = ["fixtures", "lora"]

IMPORTANT_PIPELINES = [
"controlnet",
"stable_diffusion",
"stable_diffusion_2",
"stable_diffusion_xl",
"stable_video_diffusion",
"deepfloyd_if",
"kandinsky",
"kandinsky2_2",
"text_to_video_synthesis",
"wuerstchen",
]

# Ignore fixtures in tests folder
# Ignore lora since they are always tested
MODULES_TO_IGNORE = ["fixtures", "lora"]


@contextmanager
def checkout_commit(repo: Repo, commit_id: str):
Expand Down Expand Up @@ -289,10 +289,13 @@ def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
repo = Repo(PATH_TO_REPO)

if not diff_with_last_commit:
print(f"main is at {repo.refs.main.commit}")
# Need to fetch refs for main using remotes when running with github actions.
upstream_main = repo.remotes.origin.refs.main

print(f"main is at {upstream_main.commit}")
print(f"Current head is at {repo.head.commit}")

branching_commits = repo.merge_base(repo.refs.main, repo.head)
branching_commits = repo.merge_base(upstream_main, repo.head)
for commit in branching_commits:
print(f"Branching commit: {commit}")
return get_diff(repo, repo.head.commit, branching_commits)
Expand Down Expand Up @@ -415,10 +418,11 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:

test_files_to_run = [] # noqa
if not diff_with_last_commit:
print(f"main is at {repo.refs.main.commit}")
upstream_main = repo.remotes.origin.refs.main
print(f"main is at {upstream_main.commit}")
print(f"Current head is at {repo.head.commit}")

branching_commits = repo.merge_base(repo.refs.main, repo.head)
branching_commits = repo.merge_base(upstream_main, repo.head)
for commit in branching_commits:
print(f"Branching commit: {commit}")
test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, branching_commits)
Expand All @@ -432,7 +436,7 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
all_test_files_to_run = get_all_doctest_files()

# Add to the test files to run any removed entry from "utils/not_doctested.txt".
new_test_files = get_new_doctest_files(repo, repo.head.commit, repo.refs.main.commit)
new_test_files = get_new_doctest_files(repo, repo.head.commit, upstream_main.commit)
test_files_to_run = list(set(test_files_to_run + new_test_files))

# Do not run slow doctest tests on CircleCI
Expand Down Expand Up @@ -774,18 +778,16 @@ def create_reverse_dependency_map() -> Dict[str, List[str]]:
return reverse_map


def create_module_to_test_map(
reverse_map: Dict[str, List[str]] = None, filter_models: bool = False
) -> Dict[str, List[str]]:
def create_module_to_test_map(reverse_map: Dict[str, List[str]] = None) -> Dict[str, List[str]]:
"""
Extract the tests from the reverse_dependency_map and potentially filters the model tests.
Args:
reverse_map (`Dict[str, List[str]]`, *optional*):
The reverse dependency map as created by `create_reverse_dependency_map`. Will default to the result of
that function if not provided.
filter_models (`bool`, *optional*, defaults to `False`):
Whether or not to filter model tests to only include core models if a file impacts a lot of models.
filter_pipelines (`bool`, *optional*, defaults to `False`):
Whether or not to filter pipeline tests to only include core pipelines if a file impacts a lot of models.
Returns:
`Dict[str, List[str]]`: A dictionary that maps each file to the tests to execute if that file was modified.
Expand All @@ -804,21 +806,7 @@ def is_test(fname):
# Build the test map
test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}

if not filter_models:
return test_map

# Now we deal with the filtering if `filter_models` is True.
num_model_tests = len(list(PATH_TO_TESTS.glob("models/*")))

def has_many_models(tests):
# We filter to core models when a given file impacts more than half the model tests.
model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
return len(model_tests) > num_model_tests // 2

def filter_tests(tests):
return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_PIPELINES]

return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()}
return test_map


def check_imports_all_exist():
Expand All @@ -844,7 +832,39 @@ def _print_list(l) -> str:
return "\n".join([f"- {f}" for f in l])


def create_json_map(test_files_to_run: List[str], json_output_file: str):
def update_test_map_with_core_pipelines(json_output_file: str):
    """Ensure the core pipeline tests are always scheduled.

    Loads the JSON test map from `json_output_file`, adds a dedicated
    "core_pipelines" group containing the test paths for every pipeline in
    `IMPORTANT_PIPELINES`, removes those same pipelines from the fetched
    "pipelines" group so they are not run twice, and writes the updated map
    back to `json_output_file`.

    Args:
        json_output_file (`str`):
            Path to the JSON file holding the module-to-tests map, as produced
            by `create_json_map`.
    """
    print(f"\n### ADD CORE PIPELINE TESTS ###\n{_print_list(IMPORTANT_PIPELINES)}")
    with open(json_output_file, "rb") as fp:
        test_map = json.load(fp)

    # Add core pipelines as their own test group.
    test_map["core_pipelines"] = " ".join(
        sorted([str(PATH_TO_TESTS / f"pipelines/{pipe}") for pipe in IMPORTANT_PIPELINES])
    )

    # If there are no existing pipeline tests, save the map and stop here.
    # Without this early return, the `pop("pipelines")` below would raise a
    # KeyError after the map has already been written.
    if "pipelines" not in test_map:
        with open(json_output_file, "w", encoding="UTF-8") as fp:
            json.dump(test_map, fp, ensure_ascii=False)
        return

    pipeline_tests = test_map.pop("pipelines")
    pipeline_tests = pipeline_tests.split(" ")

    # Remove core pipeline tests from the fetched pipeline tests; the bare
    # "tests/pipelines" entry is dropped as well since the core group already
    # pins the important pipelines explicitly.
    updated_pipeline_tests = []
    for pipe in pipeline_tests:
        if pipe == "tests/pipelines" or Path(pipe).parts[2] in IMPORTANT_PIPELINES:
            continue
        updated_pipeline_tests.append(pipe)

    if len(updated_pipeline_tests) > 0:
        test_map["pipelines"] = " ".join(sorted(updated_pipeline_tests))

    with open(json_output_file, "w", encoding="UTF-8") as fp:
        json.dump(test_map, fp, ensure_ascii=False)


def create_json_map(test_files_to_run: List[str], json_output_file: Optional[str] = None):
"""
Creates a map from a list of tests to run to easily split them by category, when running parallelism of slow tests.
Expand Down Expand Up @@ -881,14 +901,14 @@ def create_json_map(test_files_to_run: List[str], json_output_file: str):
# sort the keys & values
keys = sorted(test_map.keys())
test_map = {k: " ".join(sorted(test_map[k])) for k in keys}

with open(json_output_file, "w", encoding="UTF-8") as fp:
json.dump(test_map, fp, ensure_ascii=False)


def infer_tests_to_run(
output_file: str,
diff_with_last_commit: bool = False,
filter_models: bool = True,
json_output_file: Optional[str] = None,
):
"""
Expand Down Expand Up @@ -929,8 +949,9 @@ def infer_tests_to_run(
# Grab the corresponding test files:
if any(x in modified_files for x in ["setup.py"]):
test_files_to_run = ["tests", "examples"]

# in order to trigger pipeline tests even if no code change at all
elif "tests/utils/tiny_model_summary.json" in modified_files:
if "tests/utils/tiny_model_summary.json" in modified_files:
test_files_to_run = ["tests"]
any(f.split(os.path.sep)[0] == "utils" for f in modified_files)
else:
Expand All @@ -939,7 +960,7 @@ def infer_tests_to_run(
f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test")
]
# Then we grab the corresponding test files.
test_map = create_module_to_test_map(reverse_map=reverse_map, filter_models=filter_models)
test_map = create_module_to_test_map(reverse_map=reverse_map)
for f in modified_files:
if f in test_map:
test_files_to_run.extend(test_map[f])
Expand Down Expand Up @@ -1064,8 +1085,6 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
args = parser.parse_args()
if args.print_dependencies_of is not None:
print_tree_deps_of(args.print_dependencies_of)
elif args.filter_tests:
filter_tests(args.output_file, ["pipelines", "repo_utils"])
else:
repo = Repo(PATH_TO_REPO)
commit_message = repo.head.commit.message
Expand All @@ -1089,9 +1108,10 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
args.output_file,
diff_with_last_commit=diff_with_last_commit,
json_output_file=args.json_output_file,
filter_models=not commit_flags["no_filter"],
)
filter_tests(args.output_file, ["repo_utils"])
update_test_map_with_core_pipelines(json_output_file=args.json_output_file)

except Exception as e:
print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
commit_flags["test_all"] = True
Expand All @@ -1105,3 +1125,4 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:

test_files_to_run = get_all_tests()
create_json_map(test_files_to_run, args.json_output_file)
update_test_map_with_core_pipelines(json_output_file=args.json_output_file)

0 comments on commit b217292

Please sign in to comment.