diff --git a/bids2table/entities.py b/bids2table/entities.py index 316fcdc..4f22dc5 100644 --- a/bids2table/entities.py +++ b/bids2table/entities.py @@ -4,29 +4,37 @@ import re import warnings -from dataclasses import asdict, dataclass, field, fields +from dataclasses import asdict, dataclass, field, fields, make_dataclass from functools import lru_cache from pathlib import Path from types import MappingProxyType -from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +import bidsschematools.schema import pandas as pd from elbow.typing import StrOrPath -from typing_extensions import get_args, get_origin - -BIDS_DATATYPES = ( - "anat", - "beh", - "dwi", - "eeg", - "fmap", - "func", - "ieeg", - "meg", - "micr", - "perf", - "pet", -) +from typing_extensions import Self, get_args, get_origin + +_bids_schema_path: Optional[str] = None + + +def set_bids_schema_path(path: Optional[str]): + """ + Set the path to the BIDS schema. + """ + global _bids_schema_path + _bids_schema_path = path + + +def _get_bids_schema() -> Any: + """ + Get the BIDS schema. + """ + global _bids_schema_path + return bidsschematools.schema.load_schema(schema_path=_bids_schema_path) + + +BIDS_DATATYPES = tuple(o.value for o in _get_bids_schema().objects.datatypes.values()) def bids_field( @@ -58,7 +66,7 @@ def bids_field( @dataclass -class BIDSEntities: +class _BIDSEntitiesBase: """ A dataclass representing known BIDS entities. @@ -72,52 +80,7 @@ class BIDSEntities: sub: str = bids_field(name="subject", display_name="Subject", required=True) ses: Optional[str] = bids_field(name="session", display_name="Session") - sample: Optional[str] = bids_field(name="sample", display_name="Sample") - task: Optional[str] = bids_field(name="task", display_name="Task") - acq: Optional[str] = bids_field(name="acquisition", display_name="Acquisition") - ce: Optional[str] = bids_field( - name="ceagent", display_name="Contrast Enhancing Agent" - ) - trc: Optional[str] = bids_field(name="tracer", display_name="Tracer") - stain: Optional[str] = bids_field(name="stain", display_name="Stain") - rec: Optional[str] = bids_field( - name="reconstruction", display_name="Reconstruction" - ) - dir: Optional[str] = bids_field( - name="direction", display_name="Phase-Encoding Direction" - ) - run: Optional[int] = bids_field(name="run", display_name="Run") - mod: Optional[str] = bids_field( - name="modality", display_name="Corresponding Modality" - ) - echo: Optional[int] = bids_field(name="echo", display_name="Echo") - flip: Optional[int] = bids_field(name="flip", display_name="Flip Angle") - inv: Optional[int] = bids_field(name="inversion", display_name="Inversion Time") - mt: Optional[str] = bids_field( - name="mtransfer", - display_name="Magnetization Transfer", - allowed_values={"on", "off"}, - ) - part: Optional[str] = bids_field( - name="part", - display_name="Part", - allowed_values={"mag", "phase", "real", "imag"}, - ) - proc: Optional[str] = bids_field( - name="processing", display_name="Processed (on device)" - ) - hemi: Optional[str] = bids_field( - name="hemisphere", display_name="Hemisphere", allowed_values={"L", "R"} - ) - space: Optional[str] = bids_field(name="space", display_name="Space") - split: Optional[int] = bids_field(name="split", display_name="Split") - recording: Optional[str] = bids_field(name="recording", display_name="Recording") - chunk: Optional[int] = bids_field(name="chunk", display_name="Chunk") - atlas: Optional[str] = bids_field(name="atlas", display_name="Atlas") - res: Optional[str] = bids_field(name="resolution", display_name="Resolution") - den: Optional[str] = bids_field(name="density", display_name="Density") - label: Optional[str] = bids_field(name="label", display_name="Label") - desc: Optional[str] = bids_field(name="description", display_name="Description") + datatype: Optional[str] = bids_field( name="datatype", display_name="Data type", allowed_values=BIDS_DATATYPES ) @@ -248,7 +211,7 @@ def to_path( def with_update( self, entitities: Optional[Dict[str, Any]] = None, **kwargs - ) -> "BIDSEntities": + ) -> Self: """ Create a new instance with updated entities. """ @@ -257,7 +220,39 @@ def with_update( data.update(entitities) if kwargs: data.update(kwargs) - return BIDSEntities.from_dict(data) + return self.__class__.from_dict(data) + + +def make_bids_field( + entity_schema: Dict[str, Any], +) -> Tuple[str, Any, Any]: + """ + BIDS entity dataclass field. + """ + metadata = { + "name": entity_schema["name"], + "display_name": entity_schema["display_name"], + "allowed_values": entity_schema.get("enum"), + } + type_ = {"index": int}.get(entity_schema["format"], str) + field_ = field(default=None, metadata=metadata) + return entity_schema["name"], Optional[type_], field_ + + +_sorted_entities = [ + _get_bids_schema().objects.entities[field_name] + for field_name in _get_bids_schema().rules.entities +] +BIDSEntities: Any = make_dataclass( + "BIDSEntities", + [ + make_bids_field(dict(f)) + for f in _sorted_entities + if f.name not in {f.name for f in fields(_BIDSEntitiesBase)} + ], + bases=(_BIDSEntitiesBase,), +) +del _sorted_entities def _get_type(alias: Any) -> type: diff --git a/pyproject.toml b/pyproject.toml index 9d65b38..c687003 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ "elbow", "nibabel", "pandas", + "bidsschematools", + "typing_extensions", ] dynamic = ["version"] diff --git a/tests/test_bids2table.py b/tests/test_bids2table.py index 4f58642..d97cb06 100644 --- a/tests/test_bids2table.py +++ b/tests/test_bids2table.py @@ -4,6 +4,7 @@ import pytest from bids2table import bids2table +from bids2table.entities import ENTITY_NAMES_TO_KEYS BIDS_EXAMPLES = Path(__file__).parent.parent / "bids-examples" @@ -35,7 +36,7 @@ def test_bids2table(tmp_path: Path, persistent: bool, with_meta: bool): tab = bids2table( root=root, with_meta=with_meta, persistent=persistent, index_path=index_path ) - assert tab.shape == (128, 40) + assert tab.shape == (128, len(ENTITY_NAMES_TO_KEYS) + 8) if not with_meta: assert tab.loc[0, "meta__json"] is None diff --git a/tests/test_main.py b/tests/test_main.py index b4df3f4..9a399bd 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,6 +7,7 @@ import pytest from bids2table import __main__ as cli +from bids2table.entities import ENTITY_NAMES_TO_KEYS BIDS_EXAMPLES = Path(__file__).parent.parent / "bids-examples" @@ -37,7 +38,7 @@ def test_main(tmp_path: Path): cli.main() df = pd.read_parquet(output) - assert df.shape == (128, 40) + assert df.shape == (128, len(ENTITY_NAMES_TO_KEYS) + 8) if __name__ == "__main__": diff --git a/tests/test_table.py b/tests/test_table.py index 6b9a0a6..cbe8c2f 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -6,6 +6,7 @@ from bids2table import bids2table from bids2table.table import ( + ENTITY_NAMES_TO_KEYS, BIDSTable, flat_to_multi_columns, join_bids_path, @@ -32,13 +33,13 @@ def tab_no_meta() -> BIDSTable: def test_table(tab: BIDSTable): - assert tab.shape == (128, 40) + assert tab.shape == (128, len(ENTITY_NAMES_TO_KEYS) + 8) groups = tab.nested.columns.unique(0).tolist() assert groups == ["ds", "ent", "meta", "finfo"] assert tab.ds.shape == (128, 4) - assert tab.ent.shape == (128, 32) + assert tab.ent.shape == (128, len(ENTITY_NAMES_TO_KEYS)) assert tab.meta.shape == (128, 1) assert tab.flat_meta.shape == (128, 2) assert tab.finfo.shape == (128, 3) @@ -231,7 +232,7 @@ def test_join_bids_path( expected: str, ): path = join_bids_path(entities, prefix=prefix, valid_only=valid_only) - assert str(path) == expected + assert Path(path).as_posix() == expected if __name__ == "__main__":