-
Notifications
You must be signed in to change notification settings - Fork 1
/
20_benchmark_quality.py
176 lines (151 loc) · 5.43 KB
/
20_benchmark_quality.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
Benchmark for quality of reconstruction.
This script benchmarks the quality of MRI reconstruction using various algorithms and trajectories.
It requires PySAP-MRI to be installed and uses several libraries to measure performance and quality metrics such as SSIM and SNR.
Usage:
python 20_benchmark_quality.py --config-name ismrm2024
Output:
Saves reconstructed images and quality metrics in JSON format.
"""
import json
import logging
import os
from pathlib import Path
import hydra
import numpy as np
from hydra_callbacks.logger import PerfLogger
from hydra_callbacks.monitor import ResourceMonitorService
from modopt.math.metrics import snr, ssim
from modopt.opt.linear import Identity
from mrinufft import get_operator
from mrinufft.density import voronoi
from mrinufft.io import read_trajectory
from mri.operators.proximity import AutoWeightedSparseThreshold
from solver_utils import get_grad_op, OPTIMIZERS, initialize_opt, WaveletTransform
# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)
@hydra.main(version_base=None, config_path="qual", config_name="ismrm2024")
def main(cfg):
    """Run a quality benchmark of iterative MRI reconstruction.

    Reads a trajectory and a reference image, simulates (or loads cached)
    k-space data with a high-precision NUFFT, runs an iterative solver with
    a wavelet-sparsity prior, and writes the reconstructed image plus
    quality and resource metrics (SSIM, SNR, peak memory, run time) to disk
    as ``.npy`` and JSON files.

    Parameters
    ----------
    cfg : omegaconf.DictConfig
        Hydra configuration (``qual`` config group, ``ismrm2024`` by default).

    Raises
    ------
    ValueError
        If the reference image shape does not match the trajectory's
        ``img_size`` parameter.
    """
    # Read and preprocess trajectory data.
    traj, params = read_trajectory(str(Path(__file__).parent / (cfg.trajectory.file)))
    traj = np.float32(traj)
    shape = tuple(params["img_size"])
    ref_data = np.load(Path(__file__).parent / cfg.ref_data)
    # Check for shape consistency between reference image and trajectory.
    if ref_data.shape != shape:
        raise ValueError("shape mismatch between reference data and trajectory.")
    traj_base = Path(cfg.trajectory.file).stem
    cache_dir = Path(__file__).parent / cfg.cache_dir
    ksp_file = cache_dir / f"{traj_base}_ksp.npy"
    try:
        # Try to load cached k-space data first (EAFP).
        ksp_data = np.load(ksp_file)
    except FileNotFoundError:
        # Generate the k-space data with a high-precision NUFFT and cache it.
        finufft = get_operator(
            "finufft",
            traj,
            params["img_size"],
            n_coils=1,
            smaps=None,
            density=False,
            eps=6e-8,  # very tight tolerance: this acts as the reference operator
        )
        ksp_data = finufft.op(ref_data)
        if getattr(cfg, "cache_dir", None):
            os.makedirs(cache_dir, exist_ok=True)
            np.save(ksp_file, ksp_data)
    # Compute Voronoi density-compensation weights when requested; otherwise
    # forward the configured value to the operator unchanged.
    if cfg.trajectory.density == "voronoi":
        density = voronoi(traj)
    else:
        density = cfg.trajectory.density
    # Initialize the Fourier operator under benchmark (single coil).
    fourier_op = get_operator(
        cfg.backend.name,
        traj,
        shape,
        n_coils=1,
        smaps=None,
        density=density,
        eps=cfg.backend.eps,
        upsampfac=cfg.backend.upsampfac,
        squeeze_dims=True,
    )
    # Setup linear operator (wavelet transform) and sparsity regularizer.
    linear_op = WaveletTransform(
        wavelet_name=cfg.solver.wavelet.base,
        shape=shape,
        level=cfg.solver.wavelet.nb_scale,
        n_coils=1,
        mode="periodization",
    )
    regularizer_op = AutoWeightedSparseThreshold(
        linear_op.coeffs_shape,
        linear=Identity(),
        update_period=0,  # the weight is updated only once.
        sigma_range="global",
        thresh_range="global",
        threshold_estimation="sure",
        thresh_type="soft",
    )
    # Setup gradient operator and solver.
    grad_op = get_grad_op(
        fourier_op,
        OPTIMIZERS[cfg.solver.optimizer],
        linear_op,
    )
    grad_op._obs_data = ksp_data
    solver = initialize_opt(
        cfg.solver.optimizer,
        grad_op,
        linear_op,
        regularizer_op,
        opt_kwargs={"cost": None, "progress": True},
    )
    # Lazy %-style args so formatting is skipped when the level is disabled.
    logger.info("Grad inv spec rad %s", grad_op.inv_spec_rad)
    backend_sig = f"{cfg.backend.name}_{cfg.backend.eps:.0e}_{cfg.backend.upsampfac}"
    # Run the reconstruction under resource and timing monitors.
    with (
        ResourceMonitorService(
            interval=cfg.monitor.interval, gpu_monit=cfg.monitor.gpu
        ) as monit,
        PerfLogger(logger, name=backend_sig) as perflog,
    ):
        solver.iterate(max_iter=cfg.solver.max_iter)
    # Synthesis formulations solve in the wavelet domain; map back to image.
    if OPTIMIZERS[cfg.solver.optimizer] == "synthesis":
        x_final = linear_op.adj_op(solver.x_final)
    else:
        x_final = solver.x_final
    image_rec = np.abs(x_final)
    # Calculate SSIM and SNR (quality metrics) against the reference image.
    recon_ssim = ssim(image_rec, ref_data)
    recon_snr = snr(image_rec, ref_data)
    # Save the reconstructed image.
    rec_file = f"recon_{backend_sig}_{traj_base}.npy"
    np.save(rec_file, image_rec)
    logger.info("%s", backend_sig)
    logger.info("SSIM, SNR: %s, %s", recon_ssim, recon_snr)
    # BUGFIX: ssim/snr and the numpy reductions below yield numpy scalar
    # types, which json.dump cannot serialize (TypeError). Cast every
    # numpy-derived value to a built-in float before building `results`.
    results = {
        "backend": cfg.backend.name,
        "trajectory": traj_base,
        "eps": cfg.backend.eps,
        "upsampfac": cfg.backend.upsampfac,
        "end_snr": float(recon_snr),
        "end_ssim": float(recon_ssim),
        "image_rec": rec_file,
    }
    monit_values = monit.get_values()
    # Collect resource monitoring data.
    results["mem_peak"] = float(np.max(monit_values["rss_GiB"]))
    results["run_time"] = perflog.get_timer(backend_sig)
    if cfg.monitor.gpu:
        gpu_keys = [k for k in monit_values if "gpu" in k]
        for k in gpu_keys:
            results[f"{k}_avg"] = float(np.mean(monit_values[k]))
            results[f"{k}_peak"] = float(np.max(monit_values[k]))
    # Save the results to a JSON file.
    with open(f"results_{backend_sig}.json", "w") as f:
        json.dump(results, f)
# Script entry point: run the benchmark only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()