Skip to content

Commit

Permalink
Make scripts handle deleted files
Browse files Browse the repository at this point in the history
Supress printing in script tests
  • Loading branch information
IanCa committed Jun 14, 2024
1 parent cfcd205 commit 5e108a6
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 21 deletions.
3 changes: 3 additions & 0 deletions hed/scripts/script_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def sort_base_schemas(filenames):
"""
schema_files = defaultdict(set)
for file_path in filenames:
if not os.path.exists(file_path):
print(f"Ignoring deleted file {file_path}.")
continue
basename, extension = os.path.splitext(file_path)
if extension == ".xml" or extension == ".mediawiki":
schema_files[basename].add(extension)
Expand Down
15 changes: 10 additions & 5 deletions tests/scripts/test_convert_and_update_schema.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import unittest
import os
import shutil
import copy
import os
from hed import load_schema, load_schema_version
from hed.schema import HedSectionKey, HedKey
from hed.scripts.script_util import add_extension
from hed.scripts.convert_and_update_schema import convert_and_update
import contextlib


class TestConvertAndUpdate(unittest.TestCase):
Expand All @@ -24,7 +25,8 @@ def test_schema_conversion_and_update(self):

# Assume filenames updated includes just the original schema file for simplicity
filenames = [original_name]
result = convert_and_update(filenames, set_ids=False)
with contextlib.redirect_stdout(None):
result = convert_and_update(filenames, set_ids=False)

# Verify no error from convert_and_update and the correct schema version was saved
self.assertEqual(result, 0)
Expand All @@ -41,7 +43,8 @@ def test_schema_conversion_and_update(self):
schema.save_as_dataframes(tsv_filename)

filenames = [os.path.join(tsv_filename, "test_schema_Tag.tsv")]
result = convert_and_update(filenames, set_ids=False)
with contextlib.redirect_stdout(None):
result = convert_and_update(filenames, set_ids=False)

# Verify no error from convert_and_update and the correct schema version was saved
self.assertEqual(result, 0)
Expand All @@ -68,14 +71,16 @@ def test_schema_adding_tag(self):

# Assume filenames updated includes just the original schema file for simplicity
filenames = [add_extension(basename, ".mediawiki")]
result = convert_and_update(filenames, set_ids=False)
with contextlib.redirect_stdout(None):
result = convert_and_update(filenames, set_ids=False)
self.assertEqual(result, 0)

schema_reloaded = load_schema(add_extension(basename, ".xml"))

self.assertEqual(schema_reloaded, schema_edited)

result = convert_and_update(filenames, set_ids=True)
with contextlib.redirect_stdout(None):
result = convert_and_update(filenames, set_ids=True)
self.assertEqual(result, 0)

schema_reloaded = load_schema(add_extension(basename, ".xml"))
Expand Down
42 changes: 26 additions & 16 deletions tests/scripts/test_script_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
from hed import load_schema_version
from hed.scripts.script_util import add_extension, sort_base_schemas, validate_all_schema_formats, validate_schema
import contextlib


class TestAddExtension(unittest.TestCase):
Expand Down Expand Up @@ -40,7 +41,8 @@ def test_mixed_file_types(self):
"test_schema": {".mediawiki", ".tsv"},
"other_schema": {".xml"}
}
result = sort_base_schemas(filenames)
with contextlib.redirect_stdout(None):
result = sort_base_schemas(filenames)
self.assertEqual(dict(result), expected)

def test_tsv_in_correct_subfolder(self):
Expand All @@ -52,7 +54,8 @@ def test_tsv_in_correct_subfolder(self):
expected = {
"test_schema": {".tsv"}
}
result = sort_base_schemas(filenames)
with contextlib.redirect_stdout(None):
result = sort_base_schemas(filenames)
self.assertEqual(dict(result), expected)

def test_tsv_in_correct_subfolder2(self):
Expand All @@ -64,7 +67,8 @@ def test_tsv_in_correct_subfolder2(self):
expected = {
os.path.normpath("prerelease/test_schema"): {".tsv"}
}
result = sort_base_schemas(filenames)
with contextlib.redirect_stdout(None):
result = sort_base_schemas(filenames)
self.assertEqual(dict(result), expected)

def test_ignored_files(self):
Expand All @@ -75,13 +79,15 @@ def test_ignored_files(self):
expected = {
"test_schema": {".mediawiki"}
}
result = sort_base_schemas(filenames)
with contextlib.redirect_stdout(None):
result = sort_base_schemas(filenames)
self.assertEqual(dict(result), expected)

def test_empty_input(self):
filenames = []
expected = {}
result = sort_base_schemas(filenames)
with contextlib.redirect_stdout(None):
result = sort_base_schemas(filenames)
self.assertEqual(dict(result), expected)


Expand All @@ -100,19 +106,22 @@ def test_error_no_error(self):
schema = load_schema_version("8.3.0")
schema.save_as_xml(os.path.join(self.base_path, self.basename + ".xml"))
schema.save_as_dataframes(os.path.join(self.base_path, "hedtsv", self.basename))
issues = validate_all_schema_formats(os.path.join(self.base_path, self.basename))
with contextlib.redirect_stdout(None):
issues = validate_all_schema_formats(os.path.join(self.base_path, self.basename))
self.assertTrue(issues)
self.assertIn("Error loading schema", issues[0])

schema.save_as_mediawiki(os.path.join(self.base_path, self.basename + ".mediawiki"))

self.assertEqual(validate_all_schema_formats(os.path.join(self.base_path, self.basename)), [])
with contextlib.redirect_stdout(None):
self.assertEqual(validate_all_schema_formats(os.path.join(self.base_path, self.basename)), [])

schema_incorrect = load_schema_version("8.2.0")
schema_incorrect.save_as_dataframes(os.path.join(self.base_path, "hedtsv", self.basename))

# Validate and expect errors
issues = validate_all_schema_formats(os.path.join(self.base_path, self.basename))
with contextlib.redirect_stdout(None):
issues = validate_all_schema_formats(os.path.join(self.base_path, self.basename))
self.assertTrue(issues)
self.assertIn("Multiple schemas of type", issues[0])

Expand All @@ -125,11 +134,12 @@ def tearDownClass(cls):
class TestValidateSchema(unittest.TestCase):
def test_load_invalid_extension(self):
# Verify capital letters fail validation
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.MEDIAWIKI")[0])
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.Mediawiki")[0])
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.XML")[0])
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.Xml")[0])
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.TSV")[0])
self.assertNotIn("Only fully lowercase extensions ", validate_schema("does_not_matter.tsv")[0])
self.assertNotIn("Only fully lowercase extensions ", validate_schema("does_not_matter.xml")[0])
self.assertNotIn("Only fully lowercase extensions ", validate_schema("does_not_matter.mediawiki")[0])
with contextlib.redirect_stdout(None):
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.MEDIAWIKI")[0])
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.Mediawiki")[0])
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.XML")[0])
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.Xml")[0])
self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.TSV")[0])
self.assertNotIn("Only fully lowercase extensions ", validate_schema("does_not_matter.tsv")[0])
self.assertNotIn("Only fully lowercase extensions ", validate_schema("does_not_matter.xml")[0])
self.assertNotIn("Only fully lowercase extensions ", validate_schema("does_not_matter.mediawiki")[0])

0 comments on commit 5e108a6

Please sign in to comment.