From b8ca7be177658a90ca4de74fb683d762a499fc0d Mon Sep 17 00:00:00 2001 From: bareeva Date: Wed, 3 Apr 2024 11:01:55 +0200 Subject: [PATCH] housekeeping --- .github/workflows/tests.yml | 6 --- .github/workflows/type-lint.yml | 4 +- .gitignore | 54 ++++++++++++++++++++--- Makefile | 1 - pyproject.toml | 13 +++++- src/utils/datasets/group_label_dataset.py | 18 ++++---- src/utils/datasets/mark_dataset.py | 24 +++++----- tests/utils/conftest.py | 2 +- tests/utils/test_corrupt_label_dataset.py | 11 ++++- tox.ini | 2 +- 10 files changed, 94 insertions(+), 41 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 92674c39..597ba1a3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,9 +21,3 @@ jobs: - name: Measure coverage. run: tox run -e coverage - - - name: Upload - uses: codecov/codecov-action@v3 - with: - token: ${{ secrets.CODECOV_TOKEN }} - files: ./coverage.xml, coverage.xml diff --git a/.github/workflows/type-lint.yml b/.github/workflows/type-lint.yml index 36c4c609..f0c9ad08 100644 --- a/.github/workflows/type-lint.yml +++ b/.github/workflows/type-lint.yml @@ -1,4 +1,4 @@ -# .github/workflows/tests.yml +# .github/workflows/type-lint.yml name: Type-lint on: push jobs: @@ -7,7 +7,7 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Setup python 3.10 + - name: Setup python 3.11 uses: actions/setup-python@v4 with: cache: 'pip' diff --git a/.gitignore b/.gitignore index dd19bf4a..848fb101 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,54 @@ -*/*.egg-info/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ *.egg-info/ +.installed.cfg +*.egg +MANIFEST /.idea/ /.tox/ -/.coverage -.pytest_cache -*.DS_Store -__pycache__/ +# Jupyter Notebook +.ipynb_checkpoints + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# IPython +profile_default/ +ipython_config.py + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ diff --git a/Makefile b/Makefile index 435fbf6c..1dcdbaa9 100644 --- a/Makefile +++ b/Makefile @@ -5,5 +5,4 @@ SHELL = /bin/bash .PHONY: style style: black . - flake8 python3 -m isort . diff --git a/pyproject.toml b/pyproject.toml index f1466258..b5eaa535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,12 @@ line_length = 79 multi_line_output = 3 include_trailing_comma = true +[tool.mypy] +python_version = "3.11" +warn_return_any = false +warn_unused_configs = true +ignore_errors = true # TODO: change this + # Black formatting [tool.black] line-length = 120 @@ -43,11 +49,14 @@ testpaths = ["tests"] python_files = "test_*.py" [project.optional-dependencies] -tests = [ +dev = [ # Install wtih pip install .[dev] or pip install -e '.[dev]' in zsh "coverage>=7.2.3", "flake8>=6.0.0", "pytest<=7.4.4", "pytest-cov>=4.0.0", "pytest-mock==3.10.0", - "pytest_xdist", + "pre-commit>=3.2.0", + "mypy>=1.8.0", + "black[d]>=23.0.0", + "isort>=5.0.0", ] diff --git a/src/utils/datasets/group_label_dataset.py b/src/utils/datasets/group_label_dataset.py index 18a2bc85..157e6fa7 100644 --- a/src/utils/datasets/group_label_dataset.py +++ b/src/utils/datasets/group_label_dataset.py @@ -4,15 +4,6 @@ class GroupLabelDataset(Dataset): class_group_by2 = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] - @staticmethod - def check_class_groups(groups): - vals = [[] for _ in range(10)] - for g, group in enumerate(groups): - for i in group: - vals[i].append(g) - for v in vals: - assert len(v) == 1 # Check that this is the first time i is encountered - def __init__(self, dataset, class_groups=None): self.dataset = dataset self.class_labels = [i for i in range(len(class_groups))] @@ -33,3 +24,12 @@ def __getitem__(self, item): def __len__(self): return len(self.dataset) + + @staticmethod + def check_class_groups(groups): + vals = [[] for _ in range(10)] + for g, group in enumerate(groups): + for i in group: + vals[i].append(g) + for v in vals: + assert len(v) == 1 # Check that this is the first time i is encountered diff --git a/src/utils/datasets/mark_dataset.py b/src/utils/datasets/mark_dataset.py index 3183c974..bad73895 100644 --- a/src/utils/datasets/mark_dataset.py +++ b/src/utils/datasets/mark_dataset.py @@ -5,18 +5,6 @@ class MarkDataset(Dataset): - def get_mark_sample_ids(self): - indices = [] - cls = self.cls_to_mark - prob = self.mark_prob - for i in range(len(self.dataset)): - x, y = self.dataset[i] - if y == cls: - rnd = torch.rand(1) - if rnd < prob: - indices.append(i) - return torch.tensor(indices, dtype=torch.int) - def __init__(self, dataset, p=0.3, cls_to_mark=2, only_train=False): super().__init__() self.class_labels = dataset.class_labels @@ -54,6 +42,18 @@ def __getitem__(self, item): else: return x, y + def get_mark_sample_ids(self): + indices = [] + cls = self.cls_to_mark + prob = self.mark_prob + for i in range(len(self.dataset)): + x, y = self.dataset[i] + if y == cls: + rnd = torch.rand(1) + if rnd < prob: + indices.append(i) + return torch.tensor(indices, dtype=torch.int) + def mark_image_contour(self, x): x = self.dataset.inverse_transform(x) mask = torch.zeros_like(x[0]) diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py index 1cc989bc..682fd2d7 100644 --- a/tests/utils/conftest.py +++ b/tests/utils/conftest.py @@ -3,7 +3,7 @@ @pytest.fixture() -def dataset(): +def load_dataset(): x = torch.stack([torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)]) y = torch.tensor([0, 1, 0]).long() return torch.utils.data.TensorDataset(x, y) diff --git a/tests/utils/test_corrupt_label_dataset.py b/tests/utils/test_corrupt_label_dataset.py index 0a87d1aa..710683ba 100644 --- a/tests/utils/test_corrupt_label_dataset.py +++ b/tests/utils/test_corrupt_label_dataset.py @@ -4,6 +4,13 @@ @pytest.mark.utils -def test_corrupt_label_dataset(dataset): +@pytest.mark.parametrize( + "dataset, n_expected", + [ + ("load_dataset", 2), + ], +) +def test_corrupt_label_dataset(dataset, n_expected, request): + dataset = request.getfixturevalue(dataset) # cl_dataset = CorruptLabelDataset(dataset) - assert 2 == 2 + assert 2 == n_expected diff --git a/tox.ini b/tox.ini index 03f132f9..87e749c9 100644 --- a/tox.ini +++ b/tox.ini @@ -33,6 +33,6 @@ description = Run type checking base_python = py311 deps = {[testenv]deps} - mypy==0.982 + mypy==1.9.0 commands = python3 -m mypy src