This repository has been archived by the owner on Jan 18, 2024. It is now read-only.

Implement count_min and count_max
ethho committed Dec 10, 2023
1 parent b6cf232 commit 02d81c5
Showing 8 changed files with 87 additions and 44 deletions.
4 changes: 4 additions & 0 deletions datajoint_file_validator/config.py
@@ -6,3 +6,7 @@ class Config:
"""Config class for the application"""

allow_eval: bool = True
debug: bool = True


config = Config()
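
The change above adds a `debug` flag and a module-level `config` instance that other modules import instead of the `Config` class itself. A minimal sketch of that singleton pattern, using only the fields shown in this diff:

```python
# Sketch of the module-level config singleton added in this commit.
from dataclasses import dataclass


@dataclass
class Config:
    """Config class for the application"""

    allow_eval: bool = True
    debug: bool = True


config = Config()

# Downstream code imports the shared instance and can flip flags at runtime:
config.allow_eval = False
print(config)  # Config(allow_eval=False, debug=True)
```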
61 changes: 36 additions & 25 deletions datajoint_file_validator/constraint/__init__.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import Any
from ..config import Config
from ..config import config
from ..snapshot import Snapshot
from ..result import ValidationResult

Schema = Any

@@ -9,49 +11,60 @@
class Constraint:
"""A single constraint that evaluates True or False for a fileset."""

pass
def validate(self, snapshot: Snapshot) -> ValidationResult:
"""Validate a snapshot against a single constraint."""
raise NotImplementedError("Subclass of Constraint must implement validate() method.")

@property
def name(self):
_name = getattr(self, "_name", None)
return _name if _name else self.__class__.__name__


@dataclass
class CountMinConstraint(Constraint):
"""Constraint for `count_min`."""

val: int

def to_schema(self) -> Schema:
"""Convert this constraint to a Cerberus schema."""
raise NotImplementedError()
return {"minlength": self.val}
def validate(self, snapshot: Snapshot) -> ValidationResult:
status = len(snapshot) >= self.val
return ValidationResult(
status=status,
message=None if status else f"constraint `{self.name}` failed: {len(snapshot)} < {self.val}",
context=dict(snapshot=snapshot, constraint=self)
)


@dataclass
class CountMinConstraint(Constraint):
"""Constraint for `count_min`."""

class CountMaxConstraint(Constraint):
"""Constraint for `count_max`."""
val: int

def to_schema(self) -> Schema:
"""Convert this constraint to a Cerberus schema."""
raise NotImplementedError()
return {"minlength": self.val}
def validate(self, snapshot: Snapshot) -> ValidationResult:
status = len(snapshot) <= self.val
return ValidationResult(
status=status,
message=None if status else f"constraint `{self.name}` failed: {len(snapshot)} > {self.val}",
context=dict(snapshot=snapshot, constraint=self)
)


@dataclass
class RegexConstraint(Constraint):
"""Constraint for `regex`."""

val: str
class SchemaConvertibleConstraint(Constraint):

def to_schema(self) -> Schema:
"""Convert this constraint to a Cerberus schema."""
raise NotImplementedError("Subclass of SchemaConvertibleConstraint must implement to_schema() method.")

def validate(self, snapshot: Snapshot) -> ValidationResult:
"""Validate a snapshot against a single constraint."""
breakpoint()
raise NotImplementedError()
return {"regex": self.val}


@dataclass
class RegexConstraint(Constraint):
class RegexConstraint(SchemaConvertibleConstraint):
"""Constraint for `regex`."""

val: str

def to_schema(self) -> Schema:
@@ -63,12 +76,10 @@ def to_schema(self) -> Schema:
@dataclass
class EvalConstraint(Constraint):
"""Constraint for `eval`."""

val: str

def to_schema(self) -> Schema:
"""Convert this constraint to a Cerberus schema."""
if not Config.allow_eval:
def validate(self, snapshot: Snapshot) -> ValidationResult:
if not config.allow_eval:
raise DJFileValidatorError(
"Eval constraint is not allowed. "
"Set `Config.allow_eval = True` to allow."
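
As the hunks above show, `CountMinConstraint` and `CountMaxConstraint` now implement `validate()` directly by comparing `len(snapshot)` against `val` and wrapping the outcome in a `ValidationResult`. A self-contained sketch of that pattern (the `Snapshot` alias and trimmed `ValidationResult` below are simplified stand-ins for the real classes):

```python
# Standalone sketch of the count_min / count_max validation pattern.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

Snapshot = List[Dict[str, Any]]  # simplified: a snapshot is a list of file records


@dataclass
class ValidationResult:
    status: bool
    message: Optional[str] = None
    context: Dict[str, Any] = field(default_factory=dict)

    def __bool__(self) -> bool:
        return self.status


@dataclass
class CountMinConstraint:
    """Pass when the snapshot contains at least `val` files."""

    val: int

    def validate(self, snapshot: Snapshot) -> ValidationResult:
        status = len(snapshot) >= self.val
        return ValidationResult(
            status=status,
            message=None if status else f"count_min failed: {len(snapshot)} < {self.val}",
            context=dict(snapshot=snapshot, constraint=self),
        )


result = CountMinConstraint(val=3).validate([{"path": "a.csv"}, {"path": "b.csv"}])
print(bool(result), result.message)  # False count_min failed: 2 < 3
```

`CountMaxConstraint` mirrors this with `len(snapshot) <= self.val`.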
26 changes: 24 additions & 2 deletions datajoint_file_validator/manifest.py
@@ -1,11 +1,15 @@
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional
import itertools
import yaml
from .snapshot import PathLike
from .snapshot import PathLike, FileMetadata
from .query import Query, GlobQuery, DEFAULT_QUERY
from .yaml import read_yaml
from .constraint import Constraint, CONSTRAINT_MAP
from .error import DJFileValidatorError
from .result import ValidationResult
from .snapshot import Snapshot
from .config import config


@dataclass
@@ -18,6 +22,22 @@ class Rule:
constraints: List[Constraint] = field(default_factory=list)
query: Query = field(default_factory=GlobQuery)

@staticmethod
def validate_constraint(file: FileMetadata, constraint: Constraint) -> bool:
"""Validate a single constraint."""
print(f"Validating constraint {constraint} on file {file}")
return constraint.validate(file)

def validate(self, snapshot: Snapshot) -> Dict[str, ValidationResult]:
filtered_snapshot: Snapshot = self.query.filter(snapshot)
if self.query.path == DEFAULT_QUERY and config.debug:
assert filtered_snapshot == snapshot
results = list(map(lambda constraint: constraint.validate(snapshot), self.constraints))
return {
constraint.name: result
for constraint, result in zip(self.constraints, results)
}

@staticmethod
def compile_query(raw: Any) -> "Query":
assert isinstance(raw, str)
@@ -28,7 +48,9 @@ def compile_constraint(name: str, rule: Any) -> "Constraint":
if name not in CONSTRAINT_MAP:
raise DJFileValidatorError(f"Unknown constraint: {name}")
try:
return CONSTRAINT_MAP[name](rule)
constraint = CONSTRAINT_MAP[name](rule)
constraint._name = name
return constraint
except DJFileValidatorError as e:
raise DJFileValidatorError(f"Error parsing constraint {name}: {e}")

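
`Rule.compile_constraint` above looks up the constraint name from the manifest in `CONSTRAINT_MAP`, instantiates it with the raw value, and tags the instance with `_name` so that `Rule.validate` can key its results by that name (see the `name` property added to `Constraint`). A hedged sketch of that lookup-and-tag step; the map contents below are assumptions based on this commit:

```python
# Sketch of constraint compilation: manifest name -> constraint instance,
# tagged with _name so results can be keyed by the manifest's own name.
from dataclasses import dataclass


@dataclass
class CountMinConstraint:
    val: int

    @property
    def name(self):
        _name = getattr(self, "_name", None)
        return _name if _name else self.__class__.__name__


CONSTRAINT_MAP = {"count_min": CountMinConstraint}


def compile_constraint(name: str, raw):
    if name not in CONSTRAINT_MAP:
        raise ValueError(f"Unknown constraint: {name}")
    constraint = CONSTRAINT_MAP[name](raw)
    constraint._name = name
    return constraint


c = compile_constraint("count_min", 2)
print(c.name, c.val)  # count_min 2
```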
6 changes: 3 additions & 3 deletions datajoint_file_validator/path_utils.py
@@ -11,12 +11,12 @@ def find_matching_paths(


def find_matching_files_gen(snapshot: Snapshot, patterns):
filenames = (file.get('path') for file in snapshot)
filenames = [file.get('path') for file in snapshot]
return (
file for file in snapshot
if file.get('path') in set(find_matching_paths(filenames, patterns))
)


def find_matching_files(snapshot: Snapshot, path: str):
return list(find_matching_files_gen)
def find_matching_files(snapshot: Snapshot, patterns):
return list(find_matching_files_gen(snapshot, patterns))
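
The fix above makes `find_matching_files` actually call the generator with its arguments, and materializes the snapshot's paths into a list before matching. A self-contained sketch of the same filter, using `fnmatch` as a stand-in for the project's `find_matching_paths` helper:

```python
# Sketch of the fixed flow: collect snapshot paths, match them against
# glob patterns, and keep only the matching file records.
from fnmatch import fnmatch
from typing import Any, Dict, Iterable, List

Snapshot = List[Dict[str, Any]]


def find_matching_files(snapshot: Snapshot, patterns: Iterable[str]) -> Snapshot:
    filenames = [f.get("path") for f in snapshot]
    matched = {name for name in filenames if any(fnmatch(name, p) for p in patterns)}
    return [f for f in snapshot if f.get("path") in matched]


snapshot = [{"path": "data/a.csv"}, {"path": "notes.txt"}]
print(find_matching_files(snapshot, ["*.csv"]))  # [{'path': 'data/a.csv'}]
```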
13 changes: 10 additions & 3 deletions datajoint_file_validator/result.py
@@ -1,14 +1,21 @@
from dataclasses import dataclass
from typing import Dict, Any
from dataclasses import dataclass, field
from typing import Dict, Any, Optional
import cerberus


@dataclass
class ValidationResult:
status: bool
# TODO
errors: Any
message: Any
context: Optional[Dict[str, Any]] = field(default_factory=dict)

@classmethod
def from_validator(cls, v: cerberus.Validator):
return cls(status=v.status, errors=v.errors)

def __repr__(self):
return f"ValidationResult(status={self.status}, message={self.message})"

def __bool__(self) -> bool:
return self.status
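
`ValidationResult.from_validator` above wraps a `cerberus.Validator`, whose `validate()` call returns a boolean and whose `errors` attribute collects the failures; the new `__bool__` lets a result be used directly in conditionals. A small standalone example of the cerberus calls this relies on (the schema and document are illustrative):

```python
# Plain cerberus usage: validate() returns a status bool, .errors the failures.
import cerberus

schema = {"path": {"type": "string", "regex": r".*\.csv"}}
v = cerberus.Validator(schema)
ok = v.validate({"path": "notes.txt"})
print(ok)        # False
print(v.errors)  # e.g. {'path': ["value does not match regex '.*\\.csv'"]}
```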
12 changes: 7 additions & 5 deletions datajoint_file_validator/snapshot.py
@@ -10,10 +10,14 @@

@dataclass
class FileMetadata:
"""Metadata for a file."""
"""
Metadata for a file.
TODO: use wcmatch.Path instead to glob from starting directory
"""

name: str
path: str
path: str = field(init=False)
abs_path: str
rel_path: str
size: int
@@ -25,8 +29,7 @@ class FileMetadata:
_path: Optional[Path] = field(default=None, repr=False)

def __post_init__(self):
# self.id = f'{self.phrase}_{self.word_type.name.lower()}'
pass
self.path = self.rel_path

@staticmethod
def to_iso_8601(time_ns: int):
@@ -39,7 +42,6 @@ def from_path(cls, path: Path) -> "FileMetadata":
return cls(
name=path.name,
rel_path=str(path.relative_to(path.parent)),
path=str(path),
abs_path=str(path),
size=path.stat().st_size,
type="file" if path.is_file() else "directory",
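
`FileMetadata` now declares `path` with `field(init=False)` and derives it from `rel_path` in `__post_init__`, so callers such as `from_path` no longer pass `path` explicitly. A trimmed-down sketch of that derived-field pattern (only a few of the real fields are kept):

```python
# Sketch of the derived-field pattern: `path` is excluded from __init__
# and filled from rel_path after initialization.
from dataclasses import dataclass, field


@dataclass
class FileMetadata:
    name: str
    abs_path: str
    rel_path: str
    path: str = field(init=False)

    def __post_init__(self):
        self.path = self.rel_path


meta = FileMetadata(name="a.csv", abs_path="/data/a.csv", rel_path="a.csv")
print(meta.path)  # a.csv
```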
7 changes: 2 additions & 5 deletions datajoint_file_validator/validate.py
@@ -4,12 +4,9 @@
from .manifest import Manifest, Rule
from .snapshot import Snapshot, create_snapshot, PathLike
from .result import ValidationResult
from .query import DEFAULT_QUERY


def validate_rule(rule: Rule, snapshot: Snapshot):
fs: Snapshot = rule.query.filter(snapshot)
breakpoint()


def validate_snapshot(
snapshot: Snapshot, manifest_path: PathLike, verbose=False, raise_err=False
@@ -34,7 +31,7 @@ def validate_snapshot(
A dictionary with the validation result.
"""
manifest = Manifest.from_yaml(manifest_path)
results = list(map(lambda rule: validate_rule(rule, snapshot), manifest.rules))
results = list(map(lambda rule: rule.validate(snapshot), manifest.rules))
return results


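
With the per-rule `validate()` in place, `validate_snapshot` reduces to loading the manifest and mapping each rule over the snapshot, as shown above. An end-to-end sketch with stand-ins for the real `Rule` and constraint classes:

```python
# End-to-end sketch: every rule validates all of its constraints against
# the snapshot; validate_snapshot maps this over the manifest's rules.
from dataclasses import dataclass, field
from typing import Any, Dict, List

Snapshot = List[Dict[str, Any]]


@dataclass
class StubConstraint:
    name: str
    min_count: int

    def validate(self, snapshot: Snapshot) -> bool:
        return len(snapshot) >= self.min_count


@dataclass
class StubRule:
    constraints: List[StubConstraint] = field(default_factory=list)

    def validate(self, snapshot: Snapshot) -> Dict[str, bool]:
        return {c.name: c.validate(snapshot) for c in self.constraints}


def validate_snapshot(snapshot: Snapshot, rules: List[StubRule]) -> List[Dict[str, bool]]:
    return [rule.validate(snapshot) for rule in rules]


snapshot = [{"path": "a.csv"}, {"path": "b.csv"}]
rules = [StubRule([StubConstraint("count_min", 1)]), StubRule([StubConstraint("count_min", 3)])]
print(validate_snapshot(snapshot, rules))  # [{'count_min': True}, {'count_min': False}]
```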
2 changes: 1 addition & 1 deletion tests/test_query.py
@@ -2,5 +2,5 @@
import datajoint_file_validator as djfval


def test_glob_query(snapshot):
def test_glob_query():
pass
