Skip to content

Commit

Permalink
refactor BlochMcConnellSolver and BMCTool
Browse files Browse the repository at this point in the history
reactivate mypy checks
fix wrong type hints
remove is_mt_active attribute
  • Loading branch information
schuenke committed Jan 10, 2024
1 parent 73ffe4b commit 460b8ce
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 40 deletions.
26 changes: 10 additions & 16 deletions src/bmctool/BMCTool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# type: ignore
from pathlib import Path
from types import SimpleNamespace

Expand Down Expand Up @@ -96,8 +95,6 @@ def __init__(
self.params = params
self.seq_file = seq_file
self.verbose = verbose
self.run_m0_scan = None
self.bm_solver = None

# read pulseq sequence
self.seq = pp.Sequence()
Expand All @@ -124,7 +121,7 @@ def run(self) -> None:
"""Start simulation process."""

current_adc = 0
accum_phase = 0
accum_phase = 0.0

# create initial magnezitation array with correct shape
mag = self.m_init[np.newaxis, :, np.newaxis]
Expand All @@ -146,7 +143,7 @@ def run(self) -> None:
def _simulate_block(
self,
block: SimpleNamespace,
current_adc: float,
current_adc: int,
accum_phase: float,
mag: np.ndarray,
) -> tuple[int, float, np.ndarray]:
Expand Down Expand Up @@ -202,22 +199,24 @@ def _simulate_block(

return current_adc, accum_phase, mag

def _handle_adc_event(self, current_adc, accum_phase, mag):
def _handle_adc_event(self, current_adc: int, accum_phase: float, mag: np.ndarray) -> tuple[int, float, np.ndarray]:
"""Handle ADC event: write current mag to output, reset phase and increase ADC counter."""

# write current magnetization to output
self.m_out[:, current_adc] = np.squeeze(mag)

# reset phase and increase ADC counter
accum_phase = 0
accum_phase = 0.0
current_adc += 1

# reset magnetization if reset_init_mag is True
if self.params.options.reset_init_mag:
mag = self.m_init[np.newaxis, :, np.newaxis]
return current_adc, accum_phase, mag

def _handle_rf_pulse(self, block, current_adc, accum_phase, mag):
def _handle_rf_pulse(
self, block: SimpleNamespace, current_adc: int, accum_phase: float, mag: np.ndarray
) -> tuple[int, float, np.ndarray]:
"""Handle RF pulse: simulate all steps of RF pulse and update phase."""

# resample amplitude and phase of RF pulse according to max_pulse_samples
Expand All @@ -244,7 +243,7 @@ def _handle_rf_pulse(self, block, current_adc, accum_phase, mag):

return current_adc, accum_phase, mag

def _handle_spoiler_gradient(self, block, mag):
def _handle_spoiler_gradient(self, block: SimpleNamespace, mag: np.ndarray) -> np.ndarray:
"""Handle spoiler gradient: assume complete spoiling."""

_dur = block.block_duration
Expand All @@ -256,7 +255,7 @@ def _handle_spoiler_gradient(self, block, mag):

return mag

def _handle_delay_or_gradient(self, block, mag):
def _handle_delay_or_gradient(self, block: SimpleNamespace, mag: np.ndarray) -> np.ndarray:
"""Handle delay or gradient(s): simulate delay."""

_dur = block.block_duration
Expand All @@ -279,12 +278,7 @@ def get_zspec(self, return_abs: bool = True) -> tuple[np.ndarray, np.ndarray]:
Tuple of offsets and Z-spectrum
"""

if self.run_m0_scan:
m_0 = self.m_out[self.params.mz_loc, 0]
m_ = self.m_out[self.params.mz_loc, 1:]
m_z = m_ / m_0
else:
m_z = self.m_out[self.params.mz_loc, :]
m_z = self.m_out[self.params.mz_loc, :]

if self.offsets_ppm.size != m_z.size:
self.offsets_ppm = np.arange(0, m_z.size)
Expand Down
28 changes: 16 additions & 12 deletions src/bmctool/BlochMcConnellSolver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# type: ignore
import math

import numpy as np
Expand All @@ -23,7 +22,6 @@ def __init__(self, params: Parameters, n_offsets: int) -> None:
self.params: Parameters = params
self.n_offsets: int = n_offsets
self.n_pools: int = params.num_cest_pools
self.is_mt_active = bool(params.mt_pool)
self.size: int = params.m_vec.size
self.arr_a: np.ndarray
self.arr_c: np.ndarray
Expand All @@ -42,7 +40,7 @@ def _init_matrix_a(self) -> None:

# Set mt_pool parameters
k_ac = 0.0
if self.is_mt_active:
if self.params.mt_pool is not None:
k_ca = self.params.mt_pool.k
k_ac = k_ca * self.params.mt_pool.f
self.arr_a[0, 2 * (n_p + 1), 3 * (n_p + 1)] = k_ca
Expand Down Expand Up @@ -95,8 +93,8 @@ def _init_vector_c(self) -> None:
pool.f * pool.r1 for pool in self.params.cest_pools
]

if self.is_mt_active:
# Set mt_pool parameters
# Set mt_pool parameters
if self.params.mt_pool is not None:
self.arr_c[0, 3 * (n_p + 1), 0] = self.params.mt_pool.f * self.params.mt_pool.r1

def update_params(self, params: Parameters) -> None:
Expand All @@ -114,7 +112,7 @@ def update_params(self, params: Parameters) -> None:
self._init_matrix_a()
self._init_vector_c()

def update_matrix(self, rf_amp: float, rf_phase: np.ndarray, rf_freq: np.ndarray) -> None:
def update_matrix(self, rf_amp: float, rf_phase: float, rf_freq: float) -> None:
"""Update matrix self.A according to given parameters.
Parameters
Expand Down Expand Up @@ -165,7 +163,7 @@ def update_matrix(self, rf_amp: float, rf_phase: np.ndarray, rf_freq: np.ndarray
self.arr_a[:, indices + n_p + 1, indices] = dwi_values

# mt_pool
if self.is_mt_active:
if self.params.mt_pool is not None:
self.arr_a[:, 3 * (n_p + 1), 3 * (n_p + 1)] = (
-self.params.mt_pool.r1
- self.params.mt_pool.k
Expand Down Expand Up @@ -257,23 +255,26 @@ def _solve_expm(matrix: np.ndarray, dtp: float) -> np.ndarray:
inv = np.linalg.inv(vects)
return np.einsum('ijk,ikl->ijl', tmp, inv)

def get_mt_shape_at_offset(self, offsets: np.ndarray, w0: float) -> np.ndarray:
def get_mt_shape_at_offset(self, offset: float, w0: float) -> float:
"""Calculate the lineshape of the MT pool at the given offset(s).
:param offsets: frequency offset(s)
:param w0: Larmor frequency of simulated system
:return: lineshape of mt pool at given offset(s)
"""
if not self.params.mt_pool:
return 0

ls = self.params.mt_pool.lineshape.lower()
dw = self.params.mt_pool.dw
t2 = 1 / self.params.mt_pool.r2
if ls == 'lorentzian':
mt_line = t2 / (1 + pow((offsets - dw * w0) * t2, 2.0))
mt_line = t2 / (1 + pow((offset - dw * w0) * t2, 2.0))
elif ls == 'superlorentzian':
dw_pool = offsets - dw * w0
dw_pool = offset - dw * w0
mt_line = self.interpolate_sl(dw_pool) if abs(dw_pool) >= w0 else self.interpolate_chs(dw_pool, w0)
else:
mt_line = np.zeros(offsets.size)
mt_line = 0
return mt_line

def interpolate_sl(self, dw: float) -> float:
Expand All @@ -282,6 +283,9 @@ def interpolate_sl(self, dw: float) -> float:
:param dw: relative frequency offset
:return: MT profile at given relative frequency offset
"""
if not self.params.mt_pool:
return 0

mt_line = 0
t2 = 1 / self.params.mt_pool.r2
n_samples = 101
Expand All @@ -292,7 +296,7 @@ def interpolate_sl(self, dw: float) -> float:
mt_line += sqrt_2pi * t2 / powcu2 * np.exp(-2 * pow(dw * t2 / powcu2, 2))
return mt_line * np.pi * step_size

def interpolate_chs(self, dw_pool: float, w0: float) -> np.ndarray:
def interpolate_chs(self, dw_pool: float, w0: float) -> float:
"""Cubic Hermite Spline Interpolation."""
mt_line = 0
px = np.array([-300 - w0, -100 - w0, 100 + w0, 300 + w0])
Expand Down
2 changes: 1 addition & 1 deletion src/bmctool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ['Parameters', 'BMCTool', 'GAMMA_HZ']

from bmctool.parameters import Parameters
from bmctool.BMCTool import BMCTool
from bmctool.parameters import Parameters

GAMMA_HZ = 42.5764
19 changes: 8 additions & 11 deletions src/bmctool/parameters/_Parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ class Parameters:
Additional options
"""

water_pool: WaterPool = dataclasses.field(default_factory=WaterPool)
cest_pools: list = dataclasses.field(default_factory=list)
mt_pool: MTPool = dataclasses.field(default_factory=MTPool)
system: System = dataclasses.field(default_factory=System)
options: Options = dataclasses.field(default_factory=Options)
water_pool: WaterPool
cest_pools: list[CESTPool]
mt_pool: MTPool | None
system: System
options: Options

def __eq__(self, other):
if isinstance(other, Parameters):
Expand Down Expand Up @@ -100,12 +100,9 @@ def from_dict(cls, config: dict) -> Parameters:
}

#
sys_keys = [
attr for attr in System.__dict__ if not callable(getattr(System, attr)) and not attr.startswith('_')
]
opt_keys = [
attr for attr in Options.__dict__ if not callable(getattr(Options, attr)) and not attr.startswith('_')
]
sys_keys = [attr.lstrip('_') for attr in System.__slots__]

opt_keys = [attr.lstrip('_') for attr in Options.__slots__]

water_pool = WaterPool(**config['water_pool'])
cest_pools = [CESTPool(**pool) for pool in config.get('cest_pool', {}).values()]
Expand Down

0 comments on commit 460b8ce

Please sign in to comment.