Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fetch BIDS information from schema #34

Merged
merged 7 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 55 additions & 63 deletions bids2table/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,38 @@

import re
import warnings
from dataclasses import asdict, dataclass, field, fields
from dataclasses import asdict, dataclass, field, fields, make_dataclass
from functools import lru_cache
from pathlib import Path
from types import MappingProxyType
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
nx10 marked this conversation as resolved.
Show resolved Hide resolved

import bidsschematools.schema
import pandas as pd
from elbow.typing import StrOrPath
from typing_extensions import get_args, get_origin

BIDS_DATATYPES = (
"anat",
"beh",
"dwi",
"eeg",
"fmap",
"func",
"ieeg",
"meg",
"micr",
"perf",
"pet",
_bids_schema_path: Optional[str] = None


def set_bids_schema_path(path: Optional[str]):
"""
Set the path to the BIDS schema.
"""
global _bids_schema_path
_bids_schema_path = path


def _get_bids_schema() -> Any:
"""
Get the BIDS schema.
"""
global _bids_schema_path
return bidsschematools.schema.load_schema(schema_path=_bids_schema_path)


BIDS_DATATYPES = tuple(
o.value for o in dict(_get_bids_schema().objects.datatypes).values()
nx10 marked this conversation as resolved.
Show resolved Hide resolved
)


Expand Down Expand Up @@ -58,7 +68,7 @@ def bids_field(


@dataclass
class BIDSEntities:
class _BIDSEntitiesBase:
"""
A dataclass representing known BIDS entities.

Expand All @@ -72,52 +82,7 @@ class BIDSEntities:

sub: str = bids_field(name="subject", display_name="Subject", required=True)
ses: Optional[str] = bids_field(name="session", display_name="Session")
sample: Optional[str] = bids_field(name="sample", display_name="Sample")
task: Optional[str] = bids_field(name="task", display_name="Task")
acq: Optional[str] = bids_field(name="acquisition", display_name="Acquisition")
ce: Optional[str] = bids_field(
name="ceagent", display_name="Contrast Enhancing Agent"
)
trc: Optional[str] = bids_field(name="tracer", display_name="Tracer")
stain: Optional[str] = bids_field(name="stain", display_name="Stain")
rec: Optional[str] = bids_field(
name="reconstruction", display_name="Reconstruction"
)
dir: Optional[str] = bids_field(
name="direction", display_name="Phase-Encoding Direction"
)
run: Optional[int] = bids_field(name="run", display_name="Run")
mod: Optional[str] = bids_field(
name="modality", display_name="Corresponding Modality"
)
echo: Optional[int] = bids_field(name="echo", display_name="Echo")
flip: Optional[int] = bids_field(name="flip", display_name="Flip Angle")
inv: Optional[int] = bids_field(name="inversion", display_name="Inversion Time")
mt: Optional[str] = bids_field(
name="mtransfer",
display_name="Magnetization Transfer",
allowed_values={"on", "off"},
)
part: Optional[str] = bids_field(
name="part",
display_name="Part",
allowed_values={"mag", "phase", "real", "imag"},
)
proc: Optional[str] = bids_field(
name="processing", display_name="Processed (on device)"
)
hemi: Optional[str] = bids_field(
name="hemisphere", display_name="Hemisphere", allowed_values={"L", "R"}
)
space: Optional[str] = bids_field(name="space", display_name="Space")
split: Optional[int] = bids_field(name="split", display_name="Split")
recording: Optional[str] = bids_field(name="recording", display_name="Recording")
chunk: Optional[int] = bids_field(name="chunk", display_name="Chunk")
atlas: Optional[str] = bids_field(name="atlas", display_name="Atlas")
res: Optional[str] = bids_field(name="resolution", display_name="Resolution")
den: Optional[str] = bids_field(name="density", display_name="Density")
label: Optional[str] = bids_field(name="label", display_name="Label")
desc: Optional[str] = bids_field(name="description", display_name="Description")

datatype: Optional[str] = bids_field(
name="datatype", display_name="Data type", allowed_values=BIDS_DATATYPES
)
Expand Down Expand Up @@ -248,7 +213,7 @@ def to_path(

def with_update(
self, entitities: Optional[Dict[str, Any]] = None, **kwargs
) -> "BIDSEntities":
) -> "_BIDSEntitiesBase":
nx10 marked this conversation as resolved.
Show resolved Hide resolved
"""
Create a new instance with updated entities.
"""
Expand All @@ -257,7 +222,34 @@ def with_update(
data.update(entitities)
if kwargs:
data.update(kwargs)
return BIDSEntities.from_dict(data)
return self.__class__.from_dict(data)


def make_bids_field(
entity_schema: Dict[str, Any],
) -> Tuple[str, Any, Any]:
"""
BIDS entity dataclass field.
"""
metadata = {
"name": entity_schema["name"],
"display_name": entity_schema["display_name"],
"allowed_values": entity_schema.get("enum"),
}
type_ = {"integer": int}.get(entity_schema["type"], str)
nx10 marked this conversation as resolved.
Show resolved Hide resolved
field_ = field(default=None, metadata=metadata)
return entity_schema["name"], Optional[type_], field_


BIDSEntities: Any = make_dataclass(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
BIDSEntities: Any = make_dataclass(
BIDSEntities: type = make_dataclass(

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had it like this, but mypy does not like it at all (results in everything being flagged missing-attr). There is an open issue: python/mypy#6063

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool...

"BIDSEntities",
[
make_bids_field(dict(f))
for f in dict(_get_bids_schema().objects.entities).values()
if f.name not in {f.name for f in fields(_BIDSEntitiesBase)}
],
bases=(_BIDSEntitiesBase,),
)


def _get_type(alias: Any) -> type:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"elbow",
"nibabel",
"pandas",
"bidsschematools",
]
dynamic = ["version"]

Expand Down
3 changes: 2 additions & 1 deletion tests/test_bids2table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from bids2table import bids2table
from bids2table.entities import ENTITY_NAMES_TO_KEYS

BIDS_EXAMPLES = Path(__file__).parent.parent / "bids-examples"

Expand Down Expand Up @@ -35,7 +36,7 @@ def test_bids2table(tmp_path: Path, persistent: bool, with_meta: bool):
tab = bids2table(
root=root, with_meta=with_meta, persistent=persistent, index_path=index_path
)
assert tab.shape == (128, 40)
assert tab.shape == (128, len(ENTITY_NAMES_TO_KEYS) + 8)

if not with_meta:
assert tab.loc[0, "meta__json"] is None
Expand Down
2 changes: 1 addition & 1 deletion tests/test_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
sub="A01",
ses="B02",
task="rest",
run=1,
run="1",
nx10 marked this conversation as resolved.
Show resolved Hide resolved
datatype="func",
suffix="bold",
ext=".nii",
Expand Down
3 changes: 2 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from bids2table import __main__ as cli
from bids2table.entities import ENTITY_NAMES_TO_KEYS

BIDS_EXAMPLES = Path(__file__).parent.parent / "bids-examples"

Expand Down Expand Up @@ -37,7 +38,7 @@ def test_main(tmp_path: Path):
cli.main()

df = pd.read_parquet(output)
assert df.shape == (128, 40)
assert df.shape == (128, len(ENTITY_NAMES_TO_KEYS) + 8)


if __name__ == "__main__":
Expand Down
7 changes: 4 additions & 3 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from bids2table import bids2table
from bids2table.table import (
ENTITY_NAMES_TO_KEYS,
BIDSTable,
flat_to_multi_columns,
join_bids_path,
Expand All @@ -32,13 +33,13 @@ def tab_no_meta() -> BIDSTable:


def test_table(tab: BIDSTable):
assert tab.shape == (128, 40)
assert tab.shape == (128, len(ENTITY_NAMES_TO_KEYS) + 8)

groups = tab.nested.columns.unique(0).tolist()
assert groups == ["ds", "ent", "meta", "finfo"]

assert tab.ds.shape == (128, 4)
assert tab.ent.shape == (128, 32)
assert tab.ent.shape == (128, len(ENTITY_NAMES_TO_KEYS))
assert tab.meta.shape == (128, 1)
assert tab.flat_meta.shape == (128, 2)
assert tab.finfo.shape == (128, 3)
Expand Down Expand Up @@ -69,7 +70,7 @@ def test_table_files(tab: BIDSTable):
}

ents = file.entities
assert (ents.sub, ents.task, ents.run) == ("01", "balloonanalogrisktask", 1)
assert (ents.sub, ents.task, ents.run) == ("01", "balloonanalogrisktask", "01")
nx10 marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
Expand Down
Loading