Skip to content

Commit

Permalink
Fetch BIDS information from schema (#34)
Browse files Browse the repository at this point in the history
* Generate BIDS entity information from schema

* Sort BIDS entities according to schema

* Apply suggestions from code review

Co-authored-by: Chris Markiewicz <[email protected]>

* Revert test changes for run type int

* Run black

* Self

* typing

---------

Co-authored-by: Chris Markiewicz <[email protected]>
  • Loading branch information
nx10 and effigies authored Jun 21, 2024
1 parent b22f036 commit 7650e8c
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 71 deletions.
127 changes: 61 additions & 66 deletions bids2table/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,37 @@

import re
import warnings
from dataclasses import asdict, dataclass, field, fields
from dataclasses import asdict, dataclass, field, fields, make_dataclass
from functools import lru_cache
from pathlib import Path
from types import MappingProxyType
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import bidsschematools.schema
import pandas as pd
from elbow.typing import StrOrPath
from typing_extensions import get_args, get_origin

BIDS_DATATYPES = (
"anat",
"beh",
"dwi",
"eeg",
"fmap",
"func",
"ieeg",
"meg",
"micr",
"perf",
"pet",
)
from typing_extensions import Self, get_args, get_origin

_bids_schema_path: Optional[str] = None


def set_bids_schema_path(path: Optional[str]):
"""
Set the path to the BIDS schema.
"""
global _bids_schema_path
_bids_schema_path = path


def _get_bids_schema() -> Any:
"""
Get the BIDS schema.
"""
global _bids_schema_path
return bidsschematools.schema.load_schema(schema_path=_bids_schema_path)


BIDS_DATATYPES = tuple(o.value for o in _get_bids_schema().objects.datatypes.values())


def bids_field(
Expand Down Expand Up @@ -58,7 +66,7 @@ def bids_field(


@dataclass
class BIDSEntities:
class _BIDSEntitiesBase:
"""
A dataclass representing known BIDS entities.
Expand All @@ -72,52 +80,7 @@ class BIDSEntities:

sub: str = bids_field(name="subject", display_name="Subject", required=True)
ses: Optional[str] = bids_field(name="session", display_name="Session")
sample: Optional[str] = bids_field(name="sample", display_name="Sample")
task: Optional[str] = bids_field(name="task", display_name="Task")
acq: Optional[str] = bids_field(name="acquisition", display_name="Acquisition")
ce: Optional[str] = bids_field(
name="ceagent", display_name="Contrast Enhancing Agent"
)
trc: Optional[str] = bids_field(name="tracer", display_name="Tracer")
stain: Optional[str] = bids_field(name="stain", display_name="Stain")
rec: Optional[str] = bids_field(
name="reconstruction", display_name="Reconstruction"
)
dir: Optional[str] = bids_field(
name="direction", display_name="Phase-Encoding Direction"
)
run: Optional[int] = bids_field(name="run", display_name="Run")
mod: Optional[str] = bids_field(
name="modality", display_name="Corresponding Modality"
)
echo: Optional[int] = bids_field(name="echo", display_name="Echo")
flip: Optional[int] = bids_field(name="flip", display_name="Flip Angle")
inv: Optional[int] = bids_field(name="inversion", display_name="Inversion Time")
mt: Optional[str] = bids_field(
name="mtransfer",
display_name="Magnetization Transfer",
allowed_values={"on", "off"},
)
part: Optional[str] = bids_field(
name="part",
display_name="Part",
allowed_values={"mag", "phase", "real", "imag"},
)
proc: Optional[str] = bids_field(
name="processing", display_name="Processed (on device)"
)
hemi: Optional[str] = bids_field(
name="hemisphere", display_name="Hemisphere", allowed_values={"L", "R"}
)
space: Optional[str] = bids_field(name="space", display_name="Space")
split: Optional[int] = bids_field(name="split", display_name="Split")
recording: Optional[str] = bids_field(name="recording", display_name="Recording")
chunk: Optional[int] = bids_field(name="chunk", display_name="Chunk")
atlas: Optional[str] = bids_field(name="atlas", display_name="Atlas")
res: Optional[str] = bids_field(name="resolution", display_name="Resolution")
den: Optional[str] = bids_field(name="density", display_name="Density")
label: Optional[str] = bids_field(name="label", display_name="Label")
desc: Optional[str] = bids_field(name="description", display_name="Description")

datatype: Optional[str] = bids_field(
name="datatype", display_name="Data type", allowed_values=BIDS_DATATYPES
)
Expand Down Expand Up @@ -248,7 +211,7 @@ def to_path(

def with_update(
self, entitities: Optional[Dict[str, Any]] = None, **kwargs
) -> "BIDSEntities":
) -> Self:
"""
Create a new instance with updated entities.
"""
Expand All @@ -257,7 +220,39 @@ def with_update(
data.update(entitities)
if kwargs:
data.update(kwargs)
return BIDSEntities.from_dict(data)
return self.__class__.from_dict(data)


def make_bids_field(
entity_schema: Dict[str, Any],
) -> Tuple[str, Any, Any]:
"""
BIDS entity dataclass field.
"""
metadata = {
"name": entity_schema["name"],
"display_name": entity_schema["display_name"],
"allowed_values": entity_schema.get("enum"),
}
type_ = {"index": int}.get(entity_schema["format"], str)
field_ = field(default=None, metadata=metadata)
return entity_schema["name"], Optional[type_], field_


_sorted_entities = [
_get_bids_schema().objects.entities[field_name]
for field_name in _get_bids_schema().rules.entities
]
BIDSEntities: Any = make_dataclass(
"BIDSEntities",
[
make_bids_field(dict(f))
for f in _sorted_entities
if f.name not in {f.name for f in fields(_BIDSEntitiesBase)}
],
bases=(_BIDSEntitiesBase,),
)
del _sorted_entities


def _get_type(alias: Any) -> type:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ dependencies = [
"elbow",
"nibabel",
"pandas",
"bidsschematools",
"typing_extensions",
]
dynamic = ["version"]

Expand Down
3 changes: 2 additions & 1 deletion tests/test_bids2table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from bids2table import bids2table
from bids2table.entities import ENTITY_NAMES_TO_KEYS

BIDS_EXAMPLES = Path(__file__).parent.parent / "bids-examples"

Expand Down Expand Up @@ -35,7 +36,7 @@ def test_bids2table(tmp_path: Path, persistent: bool, with_meta: bool):
tab = bids2table(
root=root, with_meta=with_meta, persistent=persistent, index_path=index_path
)
assert tab.shape == (128, 40)
assert tab.shape == (128, len(ENTITY_NAMES_TO_KEYS) + 8)

if not with_meta:
assert tab.loc[0, "meta__json"] is None
Expand Down
3 changes: 2 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from bids2table import __main__ as cli
from bids2table.entities import ENTITY_NAMES_TO_KEYS

BIDS_EXAMPLES = Path(__file__).parent.parent / "bids-examples"

Expand Down Expand Up @@ -37,7 +38,7 @@ def test_main(tmp_path: Path):
cli.main()

df = pd.read_parquet(output)
assert df.shape == (128, 40)
assert df.shape == (128, len(ENTITY_NAMES_TO_KEYS) + 8)


if __name__ == "__main__":
Expand Down
7 changes: 4 additions & 3 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from bids2table import bids2table
from bids2table.table import (
ENTITY_NAMES_TO_KEYS,
BIDSTable,
flat_to_multi_columns,
join_bids_path,
Expand All @@ -32,13 +33,13 @@ def tab_no_meta() -> BIDSTable:


def test_table(tab: BIDSTable):
assert tab.shape == (128, 40)
assert tab.shape == (128, len(ENTITY_NAMES_TO_KEYS) + 8)

groups = tab.nested.columns.unique(0).tolist()
assert groups == ["ds", "ent", "meta", "finfo"]

assert tab.ds.shape == (128, 4)
assert tab.ent.shape == (128, 32)
assert tab.ent.shape == (128, len(ENTITY_NAMES_TO_KEYS))
assert tab.meta.shape == (128, 1)
assert tab.flat_meta.shape == (128, 2)
assert tab.finfo.shape == (128, 3)
Expand Down Expand Up @@ -231,7 +232,7 @@ def test_join_bids_path(
expected: str,
):
path = join_bids_path(entities, prefix=prefix, valid_only=valid_only)
assert str(path) == expected
assert Path(path).as_posix() == expected


if __name__ == "__main__":
Expand Down

0 comments on commit 7650e8c

Please sign in to comment.