Skip to content

Commit

Permalink
housekeeping
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Apr 3, 2024
1 parent ea18804 commit b8ca7be
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 41 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,3 @@ jobs:

- name: Measure coverage.
run: tox run -e coverage

- name: Upload
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml, coverage.xml
4 changes: 2 additions & 2 deletions .github/workflows/type-lint.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# .github/workflows/tests.yml
# .github/workflows/type-lint.yml
name: Type-lint
on: push
jobs:
Expand All @@ -7,7 +7,7 @@ jobs:
steps:
- uses: actions/checkout@v3

- name: Setup python 3.10
- name: Setup python 3.11
uses: actions/setup-python@v4
with:
cache: 'pip'
Expand Down
54 changes: 49 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,10 +1,54 @@
*/*.egg-info/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

/.idea/
/.tox/
/.coverage

.pytest_cache
*.DS_Store
__pycache__/
# Jupyter Notebook
.ipynb_checkpoints

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# IPython
profile_default/
ipython_config.py

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@ SHELL = /bin/bash
.PHONY: style
style:
black .
flake8
python3 -m isort .
13 changes: 11 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ line_length = 79
multi_line_output = 3
include_trailing_comma = true

[tool.mypy]
python_version = "3.11"
warn_return_any = false
warn_unused_configs = true
ignore_errors = true # TODO: change this

# Black formatting
[tool.black]
line-length = 120
Expand All @@ -43,11 +49,14 @@ testpaths = ["tests"]
python_files = "test_*.py"

[project.optional-dependencies]
tests = [
dev = [ # Install wtih pip install .[dev] or pip install -e '.[dev]' in zsh
"coverage>=7.2.3",
"flake8>=6.0.0",
"pytest<=7.4.4",
"pytest-cov>=4.0.0",
"pytest-mock==3.10.0",
"pytest_xdist",
"pre-commit>=3.2.0",
"mypy>=1.8.0",
"black[d]>=23.0.0",
"isort>=5.0.0",
]
18 changes: 9 additions & 9 deletions src/utils/datasets/group_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,6 @@
class GroupLabelDataset(Dataset):
class_group_by2 = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

@staticmethod
def check_class_groups(groups):
vals = [[] for _ in range(10)]
for g, group in enumerate(groups):
for i in group:
vals[i].append(g)
for v in vals:
assert len(v) == 1 # Check that this is the first time i is encountered

def __init__(self, dataset, class_groups=None):
self.dataset = dataset
self.class_labels = [i for i in range(len(class_groups))]
Expand All @@ -33,3 +24,12 @@ def __getitem__(self, item):

def __len__(self):
return len(self.dataset)

@staticmethod
def check_class_groups(groups):
vals = [[] for _ in range(10)]
for g, group in enumerate(groups):
for i in group:
vals[i].append(g)
for v in vals:
assert len(v) == 1 # Check that this is the first time i is encountered
24 changes: 12 additions & 12 deletions src/utils/datasets/mark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,6 @@


class MarkDataset(Dataset):
def get_mark_sample_ids(self):
indices = []
cls = self.cls_to_mark
prob = self.mark_prob
for i in range(len(self.dataset)):
x, y = self.dataset[i]
if y == cls:
rnd = torch.rand(1)
if rnd < prob:
indices.append(i)
return torch.tensor(indices, dtype=torch.int)

def __init__(self, dataset, p=0.3, cls_to_mark=2, only_train=False):
super().__init__()
self.class_labels = dataset.class_labels
Expand Down Expand Up @@ -54,6 +42,18 @@ def __getitem__(self, item):
else:
return x, y

def get_mark_sample_ids(self):
indices = []
cls = self.cls_to_mark
prob = self.mark_prob
for i in range(len(self.dataset)):
x, y = self.dataset[i]
if y == cls:
rnd = torch.rand(1)
if rnd < prob:
indices.append(i)
return torch.tensor(indices, dtype=torch.int)

def mark_image_contour(self, x):
x = self.dataset.inverse_transform(x)
mask = torch.zeros_like(x[0])
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


@pytest.fixture()
def dataset():
def load_dataset():
x = torch.stack([torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)])
y = torch.tensor([0, 1, 0]).long()
return torch.utils.data.TensorDataset(x, y)
11 changes: 9 additions & 2 deletions tests/utils/test_corrupt_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@


@pytest.mark.utils
def test_corrupt_label_dataset(dataset):
@pytest.mark.parametrize(
"dataset, n_expected",
[
("load_dataset", 2),
],
)
def test_corrupt_label_dataset(dataset, n_expected, request):
dataset = request.getfixturevalue(dataset)
# cl_dataset = CorruptLabelDataset(dataset)
assert 2 == 2
assert 2 == n_expected
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ description = Run type checking
base_python = py311
deps =
{[testenv]deps}
mypy==0.982
mypy==1.9.0
commands =
python3 -m mypy src

0 comments on commit b8ca7be

Please sign in to comment.