Skip to content

Commit

Permalink
Merge branch 'main' into param_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
schuenke committed Jan 20, 2025
2 parents e46d8cc + 2e0eaf5 commit bfe3a21
Show file tree
Hide file tree
Showing 13 changed files with 47 additions and 66 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
- name: Pytest coverage comment
id: coverageComment
uses: MishaKav/[email protected].52
uses: MishaKav/[email protected].53
with:
pytest-coverage-path: ./pytest-coverage.txt
junitxml-path: ./pytest.xml
Expand Down
18 changes: 14 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-docstring-first
- id: check-merge-conflict
Expand All @@ -14,19 +14,19 @@ repos:
- id: mixed-line-ending

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
rev: v0.8.1
hooks:
- id: ruff # linter
args: [--fix]
- id: ruff-format # formatter

- repo: https://github.com/crate-ci/typos
rev: v1.21.0
rev: typos-dict-v0.11.37
hooks:
- id: typos

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.0
rev: v1.13.0
hooks:
- id: mypy
pass_filenames: false
Expand All @@ -38,3 +38,13 @@ repos:
- pypulseq>=1.4.2
- types-pyyaml
- types-tqdm

ci:
autofix_commit_msg: |
[pre-commit] auto fixes from pre-commit hooks
autofix_prs: true
autoupdate_branch: ""
autoupdate_commit_msg: "[pre-commit] pre-commit autoupdate"
autoupdate_schedule: monthly
skip: [mypy]
submodules: false
12 changes: 1 addition & 11 deletions src/bmctool/parameters/CESTPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class CESTPool(Pool):
"""Class to store CESTPool parameters."""

__slots__ = ['_r1', '_r2', '_k', '_f', '_dw']
__slots__ = ['_dw', '_f', '_k', '_r1', '_r2']

def __init__(
self,
Expand Down Expand Up @@ -44,16 +44,6 @@ def __init__(
super().__init__(f=f, dw=dw, r1=r1, r2=r2, t1=t1, t2=t2)
self.k = k

def __dict__(self):
"""Return dictionary representation of CESTPool."""
return {
'f': self.f,
'r1': self.r1,
'r2': self.r2,
'k': self.k,
'dw': self.dw,
}

@property
def k(self) -> float:
"""Exchange rate [Hz]."""
Expand Down
13 changes: 1 addition & 12 deletions src/bmctool/parameters/MTPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class MTPool(Pool):
"""Class to store MTPool parameters."""

valid_lineshapes: typing.ClassVar[list[str]] = ['lorentzian', 'superlorentzian']
__slots__ = ['_r1', '_r2', '_k', '_f', '_dw', '_lineshape']
__slots__ = ['_dw', '_f', '_k', '_lineshape', '_r1', '_r2']

def __init__(
self,
Expand Down Expand Up @@ -52,17 +52,6 @@ def __init__(
self.k = k
self.lineshape = lineshape

def __dict__(self):
"""Return dictionary representation of MTPool."""
return {
'f': self.f,
'r1': self.r1,
'r2': self.r2,
'k': self.k,
'dw': self.dw,
'lineshape': self.lineshape,
}

@property
def k(self) -> float:
"""Exchange rate [Hz]."""
Expand Down
13 changes: 4 additions & 9 deletions src/bmctool/parameters/Options.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class Options:
"""Class to store simulation options."""

__slots__ = ['_verbose', '_reset_init_mag', '_scale', '_max_pulse_samples']
__slots__ = ['_max_pulse_samples', '_reset_init_mag', '_scale', '_verbose']

def __init__(
self,
Expand Down Expand Up @@ -45,14 +45,9 @@ def __eq__(self, other: object) -> bool:
return all(getter(self) == getter(other) for getter in attr_getters)
return False

def __dict__(self):
"""Return dictionary representation of Options."""
return {
'verbose': self.verbose,
'reset_init_mag': self.reset_init_mag,
'scale': self.scale,
'max_pulse_samples': self.max_pulse_samples,
}
def to_dict(self) -> dict:
"""Return dictionary representation with leading underscores removed."""
return {slot.lstrip('_'): getattr(self, slot) for slot in self.__slots__}

@property
def verbose(self) -> bool:
Expand Down
11 changes: 6 additions & 5 deletions src/bmctool/parameters/Parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def from_dict(cls, config: dict) -> Self:
system = System(
**{rename.get(key, key): value for key, value in config.items() if rename.get(key, key) in sys_keys}
)

options = Options(
**{rename.get(key, key): value for key, value in config.items() if rename.get(key, key) in opt_keys}
)
Expand Down Expand Up @@ -149,16 +150,16 @@ def to_yaml(self, yaml_file: str | Path) -> None:
Path to yaml file.
"""
if self.cest_pools:
cest_dict = {f'cest_{ii + 1}': pool.__dict__() for ii, pool in enumerate(self.cest_pools)}
cest_dict = {f'cest_{ii + 1}': pool.to_dict() for ii, pool in enumerate(self.cest_pools)}

with Path(yaml_file).open('w') as file:
yaml.dump({'water_pool': self.water_pool.__dict__()}, file)
yaml.dump({'water_pool': self.water_pool.to_dict()}, file)
if self.mt_pool:
yaml.dump({'mt_pool': self.mt_pool.__dict__()}, file)
yaml.dump({'mt_pool': self.mt_pool.to_dict()}, file)
if self.cest_pools:
yaml.dump({'cest_pool': cest_dict}, file)
yaml.dump(self.system.__dict__(), file)
yaml.dump(self.options.__dict__(), file)
yaml.dump(self.system.to_dict(), file)
yaml.dump(self.options.to_dict(), file)

file.close()

Expand Down
6 changes: 5 additions & 1 deletion src/bmctool/parameters/Pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class Pool(ABC):
"""Base Class for Pools."""

__slots__ = ['_r1', '_r2', '_f', '_dw']
__slots__ = ['_dw', '_f', '_r1', '_r2']

def __init__(
self,
Expand Down Expand Up @@ -68,6 +68,10 @@ def __eq__(self, other: object) -> bool:
return all(getter(self) == getter(other) for getter in attr_getters)
return False

def to_dict(self) -> dict:
"""Return dictionary representation with leading underscores removed."""
return {slot.lstrip('_'): getattr(self, slot) for slot in self.__slots__}

@property
def r1(self) -> float:
"""R1 relaxation rate [Hz] (1/T1)."""
Expand Down
13 changes: 4 additions & 9 deletions src/bmctool/parameters/System.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class System:
"""Class to store system parameters."""

__slots__ = ['_b0', '_gamma', '_b0_inhom', '_rel_b1']
__slots__ = ['_b0', '_b0_inhom', '_gamma', '_rel_b1']

def __init__(
self,
Expand Down Expand Up @@ -43,14 +43,9 @@ def __eq__(self, other: object) -> bool:
return all(getter(self) == getter(other) for getter in attr_getters)
return False

def __dict__(self):
"""Return dictionary representation of System."""
return {
'b0': self.b0,
'gamma': self.gamma,
'b0_inhom': self.b0_inhom,
'rel_b1': self.rel_b1,
}
def to_dict(self) -> dict:
"""Return dictionary representation with leading underscores removed."""
return {slot.lstrip('_'): getattr(self, slot) for slot in self.__slots__}

@property
def b0(self) -> float:
Expand Down
12 changes: 4 additions & 8 deletions src/bmctool/parameters/WaterPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class WaterPool(Pool):
"""Class to store WaterPool parameters."""

__slots__ = ['_f', '_dw', '_r1', '_r2']
__slots__ = ['_dw', '_f', '_r1', '_r2']

def __init__(
self,
Expand Down Expand Up @@ -37,13 +37,9 @@ def __init__(
"""
super().__init__(f=f, dw=0, r1=r1, r2=r2, t1=t1, t2=t2)

def __dict__(self):
"""Return dictionary representation of WaterPool."""
return {
'f': self.f,
'r1': self.r1,
'r2': self.r2,
}
def to_dict(self) -> dict:
"""Return dictionary representation with leading underscores removed."""
return {slot.lstrip('_'): getattr(self, slot) for slot in self.__slots__ if slot != '_dw'}

@property
def dw(self) -> float:
Expand Down
2 changes: 1 addition & 1 deletion src/bmctool/utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def plot_z(
ax2 = ax1.twinx()
ax2.set_ylim((round(min(mtr_asym) - 0.01, 2), round(max(mtr_asym) + 0.01, 2)))
ax2.set_ylabel('$MTR_{asym}$', color='y')
ax2.plot(offsets, mtr_asym, label='$MTR_{asym}$', color='y') # type: ignore
ax2.plot(offsets, mtr_asym, label='$MTR_{asym}$', color='y')
ax2.tick_params(axis='y', labelcolor='y')
fig.tight_layout()

Expand Down
4 changes: 2 additions & 2 deletions src/bmctool/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ def truthy_check(value: bool | int | float | str) -> bool:
"""Check if input value is truthy."""
if isinstance(value, str):
value = value.lower()
if value in {True, 1, 1.0, 'true'}:
if value in {True, 'true'}:
return True
elif value in {False, 0, 'false'}:
elif value in {False, 'false'}:
return False
raise ValueError('Input {value} cannot be converted to bool.')
5 changes: 3 additions & 2 deletions src/bmctool/utils/pulses/calc_power_equivalents.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ def calc_power_equivalent(
"""
amp = rf_pulse.signal / gamma_hz
duty_cycle = tp / (tp + td)
cwpe = np.sqrt(np.trapz(amp**2, rf_pulse.t) / tp * duty_cycle) # noqa: NPY201

return np.sqrt(np.trapz(amp**2, rf_pulse.t) / tp * duty_cycle) # type: ignore
return cwpe # type: ignore


def calc_amplitude_equivalent(
Expand Down Expand Up @@ -62,6 +63,6 @@ def calc_amplitude_equivalent(
Continuous wave amplitude equivalent value.
"""
duty_cycle = tp / (tp + td)
alpha_rad = np.trapz(rf_pulse.signal * gamma_hz * 360, rf_pulse.t) * np.pi / 180
alpha_rad = np.trapz(rf_pulse.signal * gamma_hz * 360, rf_pulse.t) * np.pi / 180 # noqa: NPY201

return alpha_rad / (gamma_hz * 2 * np.pi * tp) * duty_cycle # type: ignore
2 changes: 1 addition & 1 deletion src/bmctool/utils/pulses/make_hanning.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def make_gauss_hanning(
rf_pulse = pp.make_gauss_pulse(flip_angle=flip_angle, duration=pulse_duration, system=system, phase_offset=0)
n_signal = rf_pulse.signal.size
hanning_shape = hanning(n_signal)
rf_pulse.signal = hanning_shape / np.trapz(hanning_shape, x=rf_pulse.t) * (flip_angle / (2 * np.pi))
rf_pulse.signal = hanning_shape / np.trapz(hanning_shape, x=rf_pulse.t) * (flip_angle / (2 * np.pi)) # noqa: NPY201
return rf_pulse # type: ignore

0 comments on commit bfe3a21

Please sign in to comment.