diff --git a/.gitignore b/.gitignore index 44059f70..6f72b217 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,5 @@ site/ # Tests test/test_protein_L/tmp/ ## pytest -.pytest_cache \ No newline at end of file +.pytest_cache +.coverage diff --git a/Makefile b/Makefile index 3a87d871..431618da 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,8 @@ .PHONY: coverage coverage: - coverage run -m pytest test/test_core.py test/test_main.py test/test_fit.py test/test_cli.py + #coverage run -m pytest test/test_core.py test/test_main.py test/test_fit.py test/test_cli.py + coverage run -m pytest test/test_fitting.py test/test_lineshapes.py test/test_io.py test/test_utils.py test/test_main.py test/test_cli.py coverage-html: coverage html diff --git a/peakipy/cli/check_panel.py b/peakipy/cli/check_panel.py index 44bc00d4..2b11d131 100644 --- a/peakipy/cli/check_panel.py +++ b/peakipy/cli/check_panel.py @@ -69,7 +69,7 @@ def create_plotly_pane(cluster, plane): data_path=data.data_path, clusters=[cluster], plane=[plane], - config_path=data.config_path, + # config_path=data.config_path, plotly=True, ) diff --git a/peakipy/cli/edit.py b/peakipy/cli/edit.py index 080f5b02..3f4d31ef 100644 --- a/peakipy/cli/edit.py +++ b/peakipy/cli/edit.py @@ -34,7 +34,8 @@ from bokeh.plotting.contour import contour_data from bokeh.palettes import PuBuGn9, Category20, Viridis256, RdGy11, Reds256, YlOrRd9 -from peakipy.core import LoadData, update_args_with_values_from_config_file, StrucEl +from peakipy.io import LoadData, StrucEl +from peakipy.utils import update_args_with_values_from_config_file log_style = "overflow:scroll;" log_div = """<div style=%s>
%s
""" diff --git a/peakipy/cli/edit_panel.py b/peakipy/cli/edit_panel.py index 04b853d0..f74f5b3e 100644 --- a/peakipy/cli/edit_panel.py +++ b/peakipy/cli/edit_panel.py @@ -102,9 +102,12 @@ def fit_peaks_button_click(event): button.on_click(fit_peaks_button_click) def update_source_selected_indices(event): - # print(event) # print(bs.tablulator_widget.selection) + # hack to make current selection however, only allows one selection + # at a time + bs.tablulator_widget._update_selection([event.value]) bs.source.selected.indices = bs.tablulator_widget.selection + # print(bs.tablulator_widget.selection) bs.tablulator_widget.on_click(update_source_selected_indices) bs.tablulator_widget.on_edit(update_peakipy_data_on_edit_of_table) diff --git a/peakipy/cli/fit.py b/peakipy/cli/fit.py index 7beb8f0c..8ffc73e1 100644 --- a/peakipy/cli/fit.py +++ b/peakipy/cli/fit.py @@ -14,15 +14,17 @@ from lmfit import Model, Parameter, Parameters from lmfit.model import ModelResult -from peakipy.core import ( - fix_params, +from peakipy.lineshapes import ( Lineshape, pvoigt2d, voigt2d, pv_pv, + get_lineshape_function, +) +from peakipy.fitting import ( + fix_params, to_prefix, get_limits_for_axis_in_points, - get_lineshape_function, deal_with_peaks_on_edge_of_spectrum, select_planes_above_threshold_from_masked_data, select_reference_planes_using_indices, diff --git a/peakipy/cli/main.py b/peakipy/cli/main.py index 3a636632..79613f42 100644 --- a/peakipy/cli/main.py +++ b/peakipy/cli/main.py @@ -1,23 +1,4 @@ #!/usr/bin/env python3 -""" - - peakipy - deconvolute overlapping NMR peaks - Copyright (C) 2019 Jacob Peter Brady - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . 
- -""" import os import json import shutil @@ -42,77 +23,75 @@ from mpl_toolkits.mplot3d import axes3d from matplotlib import cm from matplotlib.backends.backend_pdf import PdfPages -from matplotlib.widgets import Button import yaml -import plotly.graph_objects as go import plotly.io as pio pio.templates.default = "plotly_dark" -from peakipy.core import ( +from peakipy.io import ( Peaklist, - run_log, LoadData, - pv_pv, - pvoigt2d, - voigt2d, - make_mask, - pv_g, - pv_l, - gaussian_lorentzian, Pseudo3D, - df_to_rich_table, StrucEl, PeaklistFormat, - Lineshape, OutFmt, + get_vclist, +) +from peakipy.utils import ( + run_log, + df_to_rich_table, write_config, update_config_file, update_args_with_values_from_config_file, - get_limits_for_axis_in_points, - deal_with_peaks_on_edge_of_spectrum, - calculate_fwhm_for_voigt_lineshape, - calculate_height_for_voigt_lineshape, - calculate_fwhm_for_pseudo_voigt_lineshape, - calculate_height_for_pseudo_voigt_lineshape, - calculate_height_for_gaussian_lineshape, - calculate_height_for_lorentzian_lineshape, - calculate_height_for_pv_pv_lineshape, + update_linewidths_from_hz_to_points, + update_peak_positions_from_ppm_to_points, + check_data_shape_is_consistent_with_dims, + check_for_include_column_and_add_if_missing, + remove_excluded_peaks, + warn_if_trying_to_fit_large_clusters, + save_data, +) + +from peakipy.lineshapes import ( + Lineshape, + calculate_lineshape_specific_height_and_fwhm, calculate_peak_centers_in_ppm, calculate_peak_linewidths_in_hz, ) +from peakipy.fitting import ( + get_limits_for_axis_in_points, + deal_with_peaks_on_edge_of_spectrum, + select_specified_planes, + exclude_specified_planes, + unpack_xy_bounds, + validate_plane_selection, + get_fit_data_for_selected_peak_clusters, + make_masks_from_plane_data, + simulate_lineshapes_from_fitted_peak_parameters, + simulate_pv_pv_lineshapes_from_fitted_peak_parameters, + validate_fit_dataframe, +) + from .fit import ( fit_peak_clusters, FitPeaksInput, FitPeaksArgs, ) +from peakipy.plotting import ( + PlottingDataForPlane, + validate_sample_count, + unpack_plotting_colors, + create_plotly_figure, + create_residual_figure, + create_matplotlib_figure, +) from .spec import yaml_file app = typer.Typer() tmp_path = Path("tmp") tmp_path.mkdir(exist_ok=True) log_path = Path("log.txt") -# for printing dataframes -peaklist_columns_for_printing = ["INDEX", "ASS", "X_PPM", "Y_PPM", "CLUSTID", "MEMCNT"] -bad_column_selection = [ - "clustid", - "amp", - "center_x_ppm", - "center_y_ppm", - "fwhm_x_hz", - "fwhm_y_hz", - "lineshape", -] -bad_color_selection = [ - "green", - "blue", - "yellow", - "red", - "yellow", - "red", - "magenta", -] peaklist_path_help = "Path to peaklist" @@ -367,225 +346,46 @@ def read( ) -def calculate_lineshape_specific_height_and_fwhm( - lineshape: Lineshape, df: pd.DataFrame -): - match lineshape: - case lineshape.V: - df = calculate_height_for_voigt_lineshape(df) - df = calculate_fwhm_for_voigt_lineshape(df) - - case lineshape.PV: - df = calculate_height_for_pseudo_voigt_lineshape(df) - df = calculate_fwhm_for_pseudo_voigt_lineshape(df) - - case lineshape.G: - df = calculate_height_for_gaussian_lineshape(df) - df = calculate_fwhm_for_pseudo_voigt_lineshape(df) - - case lineshape.L: - df = calculate_height_for_lorentzian_lineshape(df) - df = calculate_fwhm_for_pseudo_voigt_lineshape(df) - - case lineshape.PV_PV: - df = calculate_height_for_pv_pv_lineshape(df) - df = calculate_fwhm_for_pseudo_voigt_lineshape(df) - case _: - df = calculate_fwhm_for_pseudo_voigt_lineshape(df) - 
return df - - -def get_vclist(vclist, args): - # read vclist - if vclist is None: - vclist = False - elif vclist.exists(): - vclist_data = np.genfromtxt(vclist) - args["vclist_data"] = vclist_data - vclist = True - else: - raise Exception("vclist not found...") - - args["vclist"] = vclist - return args - - -def check_data_shape_is_consistent_with_dims(peakipy_data): - # check data shape is consistent with dims - if len(peakipy_data.dims) != len(peakipy_data.data.shape): - print( - f"Dims are {peakipy_data.dims} while data shape is {peakipy_data.data.shape}?" - ) - exit() - - -def select_specified_planes(plane, peakipy_data): - plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]]) - # only fit specified planes - if plane: - inds = [i for i in plane] - data_inds = [ - (i in inds) for i in range(peakipy_data.data.shape[peakipy_data.dims[0]]) - ] - plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]])[ - data_inds - ] - peakipy_data.data = peakipy_data.data[data_inds] - print( - "[yellow]Using only planes {plane} data now has the following shape[/yellow]", - peakipy_data.data.shape, - ) - if peakipy_data.data.shape[peakipy_data.dims[0]] == 0: - print("[red]You have excluded all the data![/red]", peakipy_data.data.shape) - exit() - return plane_numbers, peakipy_data - - -def exclude_specified_planes(exclude_plane, peakipy_data): - plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]]) - # do not fit these planes - if exclude_plane: - inds = [i for i in exclude_plane] - data_inds = [ - (i not in inds) - for i in range(peakipy_data.data.shape[peakipy_data.dims[0]]) - ] - plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]])[ - data_inds - ] - peakipy_data.data = peakipy_data.data[data_inds] - print( - f"[yellow]Excluding planes {exclude_plane} data now has the following shape[/yellow]", - peakipy_data.data.shape, - ) - if peakipy_data.data.shape[peakipy_data.dims[0]] == 0: - print("[red]You have excluded all the data![/red]", peakipy_data.data.shape) - exit() - return plane_numbers, peakipy_data - - -def check_for_include_column_and_add_if_missing(peakipy_data): - # only include peaks with 'include' - if "include" in peakipy_data.df.columns: - pass - else: - # for compatibility - peakipy_data.df["include"] = peakipy_data.df.apply(lambda _: "yes", axis=1) - return peakipy_data - - -def remove_excluded_peaks(peakipy_data): - if len(peakipy_data.df[peakipy_data.df.include != "yes"]) > 0: - excluded = peakipy_data.df[peakipy_data.df.include != "yes"][ - peaklist_columns_for_printing - ] - table = df_to_rich_table( - excluded, - title="[yellow] Excluded peaks [/yellow]", - columns=excluded.columns, - styles=["yellow" for i in excluded.columns], - ) - print(table) - peakipy_data.df = peakipy_data.df[peakipy_data.df.include == "yes"] - return peakipy_data - - -def warn_if_trying_to_fit_large_clusters(max_cluster_size, peakipy_data): - if max_cluster_size is None: - max_cluster_size = peakipy_data.df.MEMCNT.max() - if peakipy_data.df.MEMCNT.max() > 10: - print( - f"""[red] - ################################################################## - You have some clusters of as many as {max_cluster_size} peaks. - You may want to consider reducing the size of your clusters as the - fits will struggle. 
- - Otherwise you can use the --max-cluster-size flag to exclude large - clusters - ################################################################## - [/red]""" - ) - else: - max_cluster_size = max_cluster_size - return max_cluster_size - - -def update_linewidths_from_hz_to_points(peakipy_data): - """in case they were adjusted when running edit.py""" - peakipy_data.df["XW"] = peakipy_data.df.XW_HZ * peakipy_data.pt_per_hz_f2 - peakipy_data.df["YW"] = peakipy_data.df.YW_HZ * peakipy_data.pt_per_hz_f1 - return peakipy_data - - -def update_peak_positions_from_ppm_to_points(peakipy_data): - # convert peak positions from ppm to points in case they were adjusted running edit.py - peakipy_data.df["X_AXIS"] = peakipy_data.df.X_PPM.apply( - lambda x: peakipy_data.uc_f2(x, "PPM") - ) - peakipy_data.df["Y_AXIS"] = peakipy_data.df.Y_PPM.apply( - lambda x: peakipy_data.uc_f1(x, "PPM") - ) - peakipy_data.df["X_AXISf"] = peakipy_data.df.X_PPM.apply( - lambda x: peakipy_data.uc_f2.f(x, "PPM") - ) - peakipy_data.df["Y_AXISf"] = peakipy_data.df.Y_PPM.apply( - lambda x: peakipy_data.uc_f1.f(x, "PPM") - ) - return peakipy_data - - -def unpack_xy_bounds(xy_bounds, peakipy_data): - match xy_bounds: - case (0, 0): - xy_bounds = None - case (x, y): - # convert ppm to points - xy_bounds = list(xy_bounds) - xy_bounds[0] = xy_bounds[0] * peakipy_data.pt_per_ppm_f2 - xy_bounds[1] = xy_bounds[1] * peakipy_data.pt_per_ppm_f1 - case _: - raise TypeError( - "xy_bounds should be a tuple (, )" - ) - return xy_bounds - - -def save_data(df, output_name): - suffix = output_name.suffix - if suffix == ".csv": - df.to_csv(output_name, float_format="%.4f", index=False) - - elif suffix == ".tab": - df.to_csv(output_name, sep="\t", float_format="%.4f", index=False) - - else: - df.to_pickle(output_name) - - +fix_help = "Set parameters to fix after initial lineshape fit (see docs)" +xy_bounds_help = ( + "Restrict fitted peak centre within +/- x and y from initial picked position" +) reference_plane_index_help = ( - "Select planes to use for initial estimation of lineshape parameters" + "Select plane(s) to use for initial estimation of lineshape parameters" ) +mp_help = "Use multiprocessing" +vclist_help = "Provide a vclist style file" +plane_help = "Select individual planes for fitting" +exclude_plane_help = "Exclude individual planes from fitting" @app.command(help="Fit NMR data to lineshape models and deconvolute overlapping peaks") def fit( - peaklist_path: Path, - data_path: Path, + peaklist_path: Annotated[Path, typer.Argument(help=peaklist_path_help)], + data_path: Annotated[Path, typer.Argument(help=data_path_help)], output_path: Path, max_cluster_size: Optional[int] = None, lineshape: Lineshape = Lineshape.PV, - fix: List[str] = ["fraction", "sigma", "center"], - xy_bounds: Tuple[float, float] = (0, 0), - vclist: Optional[Path] = None, - plane: Optional[List[int]] = None, - exclude_plane: Optional[List[int]] = None, + fix: Annotated[List[str], typer.Option(help=fix_help)] = [ + "fraction", + "sigma", + "center", + ], + xy_bounds: Annotated[Tuple[float, float], typer.Option(help=xy_bounds_help)] = ( + 0, + 0, + ), + vclist: Annotated[Optional[Path], typer.Option(help=vclist_help)] = None, + plane: Annotated[Optional[List[int]], typer.Option(help=plane_help)] = None, + exclude_plane: Annotated[ + Optional[List[int]], typer.Option(help=exclude_plane_help) + ] = None, reference_plane_index: Annotated[ List[int], typer.Option(help=reference_plane_index_help) ] = [], initial_fit_threshold: Optional[float] = None, 
jack_knife_sample_errors: bool = False, - mp: bool = True, + mp: Annotated[bool, typer.Option(help=mp_help)] = True, verbose: bool = False, ): """Fit NMR data to lineshape models and deconvolute overlapping peaks @@ -633,7 +433,10 @@ def fit( # read NMR data args = {} config = {} - args, config = update_args_with_values_from_config_file(args) + data_dir = peaklist_path.parent + args, config = update_args_with_values_from_config_file( + args, config_path=data_dir / "peakipy.config" + ) dims = config.get("dims", [0, 1, 2]) peakipy_data = LoadData(peaklist_path, data_path, dims=dims) peakipy_data = check_for_include_column_and_add_if_missing(peakipy_data) @@ -733,551 +536,6 @@ def fit( run_log() -def validate_plane_selection(plane, pseudo3D): - if (plane == []) or (plane == None): - plane = list(range(pseudo3D.n_planes)) - - elif max(plane) > (pseudo3D.n_planes - 1): - raise ValueError( - f"[red]There are {pseudo3D.n_planes} planes in your data you selected --plane {max(plane)}...[red]" - f"plane numbering starts from 0." - ) - elif min(plane) < 0: - raise ValueError( - f"[red]Plane number can not be negative; you selected --plane {min(plane)}...[/red]" - ) - else: - plane = sorted(plane) - - return plane - - -def validate_sample_count(sample_count): - if type(sample_count) == int: - sample_count = sample_count - else: - raise TypeError("Sample count (ccount, rcount) should be an integer") - return sample_count - - -def unpack_plotting_colors(colors): - match colors: - case (data_color, fit_color): - data_color, fit_color = colors - case _: - data_color, fit_color = "green", "blue" - return data_color, fit_color - - -def get_fit_data_for_selected_peak_clusters(fits, clusters): - match clusters: - case None | []: - pass - case _: - # only use these clusters - fits = fits[fits.clustid.isin(clusters)] - if len(fits) < 1: - exit(f"Are you sure clusters {clusters} exist?") - return fits - - -def make_masks_from_plane_data(empty_mask_array, plane_data): - # make masks - individual_masks = [] - for cx, cy, rx, ry, name in zip( - plane_data.center_x, - plane_data.center_y, - plane_data.x_radius, - plane_data.y_radius, - plane_data.assignment, - ): - tmp_mask = make_mask(empty_mask_array, cx, cy, rx, ry) - empty_mask_array += tmp_mask - individual_masks.append(tmp_mask) - filled_mask_array = empty_mask_array - return individual_masks, filled_mask_array - - -def simulate_pv_pv_lineshapes_from_fitted_peak_parameters( - peak_parameters, XY, sim_data, sim_data_singles -): - for amp, c_x, c_y, s_x, s_y, frac_x, frac_y, ls in zip( - peak_parameters.amp, - peak_parameters.center_x, - peak_parameters.center_y, - peak_parameters.sigma_x, - peak_parameters.sigma_y, - peak_parameters.fraction_x, - peak_parameters.fraction_y, - peak_parameters.lineshape, - ): - sim_data_i = pv_pv(XY, amp, c_x, c_y, s_x, s_y, frac_x, frac_y).reshape( - sim_data.shape - ) - sim_data += sim_data_i - sim_data_singles.append(sim_data_i) - return sim_data, sim_data_singles - - -def simulate_lineshapes_from_fitted_peak_parameters( - peak_parameters, XY, sim_data, sim_data_singles -): - shape = sim_data.shape - for amp, c_x, c_y, s_x, s_y, frac, lineshape in zip( - peak_parameters.amp, - peak_parameters.center_x, - peak_parameters.center_y, - peak_parameters.sigma_x, - peak_parameters.sigma_y, - peak_parameters.fraction, - peak_parameters.lineshape, - ): - # print(amp) - match lineshape: - case "G" | "L" | "PV": - sim_data_i = pvoigt2d(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) - case "PV_L": - sim_data_i = pv_l(XY, amp, c_x, c_y, 
s_x, s_y, frac).reshape(shape) - - case "PV_G": - sim_data_i = pv_g(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) - - case "G_L": - sim_data_i = gaussian_lorentzian( - XY, amp, c_x, c_y, s_x, s_y, frac - ).reshape(shape) - - case "V": - sim_data_i = voigt2d(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) - sim_data += sim_data_i - sim_data_singles.append(sim_data_i) - return sim_data, sim_data_singles - - -@dataclass -class PlottingDataForPlane: - pseudo3D: Pseudo3D - plane_id: int - plane_lineshape_parameters: pd.DataFrame - X: np.array - Y: np.array - mask: np.array - individual_masks: List[np.array] - sim_data: np.array - sim_data_singles: List[np.array] - min_x: int - max_x: int - min_y: int - max_y: int - fit_color: str - data_color: str - rcount: int - ccount: int - - x_plot: np.array = field(init=False) - y_plot: np.array = field(init=False) - masked_data: np.array = field(init=False) - masked_sim_data: np.array = field(init=False) - residual: np.array = field(init=False) - single_colors: List = field(init=False) - - def __post_init__(self): - self.plane_data = self.pseudo3D.data[self.plane_id] - self.masked_data = self.plane_data.copy() - self.masked_sim_data = self.sim_data.copy() - self.masked_data[~self.mask] = np.nan - self.masked_sim_data[~self.mask] = np.nan - - self.x_plot = self.pseudo3D.uc_f2.ppm( - self.X[self.min_y : self.max_y, self.min_x : self.max_x] - ) - self.y_plot = self.pseudo3D.uc_f1.ppm( - self.Y[self.min_y : self.max_y, self.min_x : self.max_x] - ) - self.masked_data = self.masked_data[ - self.min_y : self.max_y, self.min_x : self.max_x - ] - self.sim_plot = self.masked_sim_data[ - self.min_y : self.max_y, self.min_x : self.max_x - ] - self.residual = self.masked_data - self.sim_plot - - for single_mask, single in zip(self.individual_masks, self.sim_data_singles): - single[~single_mask] = np.nan - self.sim_data_singles = [ - sim_data_single[self.min_y : self.max_y, self.min_x : self.max_x] - for sim_data_single in self.sim_data_singles - ] - self.single_colors = [ - cm.viridis(i) for i in np.linspace(0, 1, len(self.sim_data_singles)) - ] - - -def plot_data_is_valid(plot_data: PlottingDataForPlane) -> bool: - if len(plot_data.x_plot) < 1 or len(plot_data.y_plot) < 1: - print( - f"[red]Nothing to plot for cluster {int(plot_data.plane_lineshape_parameters.clustid)}[/red]" - ) - print(f"[red]x={plot_data.x_plot},y={plot_data.y_plot}[/red]") - print( - df_to_rich_table( - plot_data.plane_lineshape_parameters, - title="", - columns=bad_column_selection, - styles=bad_color_selection, - ) - ) - plt.close() - validated = False - # print(Fore.RED + "Maybe your F1/F2 radii for fitting were too small...") - elif plot_data.masked_data.shape[0] == 0 or plot_data.masked_data.shape[1] == 0: - print(f"[red]Nothing to plot for cluster {int(plot_data.plane.clustid)}[/red]") - print( - df_to_rich_table( - plot_data.plane_lineshape_parameters, - title="Bad plane", - columns=bad_column_selection, - styles=bad_color_selection, - ) - ) - spec_lim_f1 = " - ".join( - ["%8.3f" % i for i in plot_data.pseudo3D.f1_ppm_limits] - ) - spec_lim_f2 = " - ".join( - ["%8.3f" % i for i in plot_data.pseudo3D.f2_ppm_limits] - ) - print(f"Spectrum limits are {plot_data.pseudo3D.f2_label:4s}:{spec_lim_f2} ppm") - print(f" {plot_data.pseudo3D.f1_label:4s}:{spec_lim_f1} ppm") - plt.close() - validated = False - else: - validated = True - return validated - - -def create_matplotlib_figure( - plot_data: PlottingDataForPlane, - pdf: PdfPages, - individual=False, - label=False, - ccpn_flag=False, - 
show=True, -): - fig = plt.figure(figsize=(10, 6)) - ax = fig.add_subplot(projection="3d") - if plot_data_is_valid(plot_data): - cset = ax.contourf( - plot_data.x_plot, - plot_data.y_plot, - plot_data.residual, - zdir="z", - offset=np.nanmin(plot_data.masked_data) * 1.1, - alpha=0.5, - cmap=cm.coolwarm, - ) - cbl = fig.colorbar(cset, ax=ax, shrink=0.5, format="%.2e") - cbl.ax.set_title("Residual", pad=20) - - if individual: - #  for plotting single fit surfaces - single_colors = [ - cm.viridis(i) - for i in np.linspace(0, 1, len(plot_data.sim_data_singles)) - ] - [ - ax.plot_surface( - plot_data.x_plot, - plot_data.y_plot, - z_single, - color=c, - alpha=0.5, - ) - for c, z_single in zip(single_colors, plot_data.sim_data_singles) - ] - ax.plot_wireframe( - plot_data.x_plot, - plot_data.y_plot, - plot_data.sim_plot, - # colors=[cm.coolwarm(i) for i in np.ravel(residual)], - colors=plot_data.fit_color, - linestyle="--", - label="fit", - rcount=plot_data.rcount, - ccount=plot_data.ccount, - ) - ax.plot_wireframe( - plot_data.x_plot, - plot_data.y_plot, - plot_data.masked_data, - colors=plot_data.data_color, - linestyle="-", - label="data", - rcount=plot_data.rcount, - ccount=plot_data.ccount, - ) - ax.set_ylabel(plot_data.pseudo3D.f1_label) - ax.set_xlabel(plot_data.pseudo3D.f2_label) - - # axes will appear inverted - ax.view_init(30, 120) - - title = f"Plane={plot_data.plane_id},Cluster={plot_data.plane_lineshape_parameters.clustid.iloc[0]}" - plt.title(title) - print(f"[green]Plotting: {title}[/green]") - out_str = "Volumes (Heights)\n===========\n" - for _, row in plot_data.plane_lineshape_parameters.iterrows(): - out_str += f"{row.assignment} = {row.amp:.3e} ({row.height:.3e})\n" - if label: - ax.text( - row.center_x_ppm, - row.center_y_ppm, - row.height * 1.2, - row.assignment, - (1, 1, 1), - ) - - ax.text2D( - -0.5, - 1.0, - out_str, - transform=ax.transAxes, - fontsize=10, - fontfamily="sans-serif", - va="top", - bbox=dict(boxstyle="round", ec="k", fc="k", alpha=0.5), - ) - - ax.legend() - - if show: - - def exit_program(event): - exit() - - def next_plot(event): - plt.close() - - axexit = plt.axes([0.81, 0.05, 0.1, 0.075]) - bnexit = Button(axexit, "Exit") - bnexit.on_clicked(exit_program) - axnext = plt.axes([0.71, 0.05, 0.1, 0.075]) - bnnext = Button(axnext, "Next") - bnnext.on_clicked(next_plot) - if ccpn_flag: - plt.show(windowTitle="", size=(1000, 500)) - else: - plt.show() - else: - pdf.savefig() - - plt.close() - - -def create_plotly_wireframe_lines(plot_data: PlottingDataForPlane): - lines = [] - show_legend = lambda x: x < 1 - showlegend = False - # make simulated data wireframe - line_marker = dict(color=plot_data.fit_color, width=4) - counter = 0 - for i, j, k in zip(plot_data.x_plot, plot_data.y_plot, plot_data.sim_plot): - showlegend = show_legend(counter) - lines.append( - go.Scatter3d( - x=i, - y=j, - z=k, - mode="lines", - line=line_marker, - name="fit", - showlegend=showlegend, - ) - ) - counter += 1 - for i, j, k in zip(plot_data.x_plot.T, plot_data.y_plot.T, plot_data.sim_plot.T): - lines.append( - go.Scatter3d( - x=i, y=j, z=k, mode="lines", line=line_marker, showlegend=showlegend - ) - ) - # make experimental data wireframe - line_marker = dict(color=plot_data.data_color, width=4) - counter = 0 - for i, j, k in zip(plot_data.x_plot, plot_data.y_plot, plot_data.masked_data): - showlegend = show_legend(counter) - lines.append( - go.Scatter3d( - x=i, - y=j, - z=k, - mode="lines", - name="data", - line=line_marker, - showlegend=showlegend, - ) - ) - counter += 1 - for 
i, j, k in zip(plot_data.x_plot.T, plot_data.y_plot.T, plot_data.masked_data.T): - lines.append( - go.Scatter3d( - x=i, y=j, z=k, mode="lines", line=line_marker, showlegend=showlegend - ) - ) - - return lines - - -def construct_surface_legend_string(row): - surface_legend = "" - surface_legend += row.assignment - return surface_legend - - -def create_plotly_surfaces(plot_data: PlottingDataForPlane): - data = [] - color_scale_values = np.linspace(0, 1, len(plot_data.single_colors)) - color_scale = [ - [val, f"rgb({', '.join('%d'%(i*255) for i in c[0:3])})"] - for val, c in zip(color_scale_values, plot_data.single_colors) - ] - for val, individual_peak, row in zip( - color_scale_values, - plot_data.sim_data_singles, - plot_data.plane_lineshape_parameters.itertuples(), - ): - name = construct_surface_legend_string(row) - colors = np.zeros(shape=individual_peak.shape) + val - data.append( - go.Surface( - z=individual_peak, - x=plot_data.x_plot, - y=plot_data.y_plot, - opacity=0.5, - surfacecolor=colors, - colorscale=color_scale, - showscale=False, - cmin=0, - cmax=1, - name=name, - ) - ) - return data - - -def create_residual_contours(plot_data: PlottingDataForPlane): - contours = go.Contour( - x=plot_data.x_plot[0], y=plot_data.y_plot.T[0], z=plot_data.residual - ) - return contours - - -def create_residual_figure(plot_data: PlottingDataForPlane): - data = create_residual_contours(plot_data) - fig = go.Figure(data=data) - fig.update_layout( - title="Fit residuals", - xaxis_title=f"{plot_data.pseudo3D.f2_label} ppm", - yaxis_title=f"{plot_data.pseudo3D.f1_label} ppm", - xaxis=dict(range=[plot_data.x_plot.max(), plot_data.x_plot.min()]), - yaxis=dict(range=[plot_data.y_plot.max(), plot_data.y_plot.min()]), - ) - return fig - - -def create_plotly_figure(plot_data: PlottingDataForPlane): - lines = create_plotly_wireframe_lines(plot_data) - surfaces = create_plotly_surfaces(plot_data) - fig = go.Figure(data=lines + surfaces) - fig = update_axis_ranges(fig, plot_data) - return fig - - -def update_axis_ranges(fig, plot_data: PlottingDataForPlane): - fig.update_layout( - scene=dict( - xaxis=dict(range=[plot_data.x_plot.max(), plot_data.x_plot.min()]), - yaxis=dict(range=[plot_data.y_plot.max(), plot_data.y_plot.min()]), - xaxis_title=f"{plot_data.pseudo3D.f2_label} ppm", - yaxis_title=f"{plot_data.pseudo3D.f1_label} ppm", - annotations=make_annotations(plot_data), - ) - ) - return fig - - -def make_annotations(plot_data: PlottingDataForPlane): - annotations = [] - for row in plot_data.plane_lineshape_parameters.itertuples(): - annotations.append( - dict( - showarrow=True, - x=row.center_x_ppm, - y=row.center_y_ppm, - z=row.height * 1.0, - text=row.assignment, - opacity=0.8, - textangle=0, - arrowsize=1, - ) - ) - return annotations - - -class FitDataModel(BaseModel): - plane: int - clustid: int - assignment: str - memcnt: int - amp: float - height: float - center_x_ppm: float - center_y_ppm: float - fwhm_x_hz: float - fwhm_y_hz: float - lineshape: str - x_radius: float - y_radius: float - center_x: float - center_y: float - sigma_x: float - sigma_y: float - - -class FitDataModelPVGL(FitDataModel): - fraction: float - - -class FitDataModelVoigt(FitDataModel): - fraction: float - gamma_x: float - gamma_y: float - - -class FitDataModelPVPV(FitDataModel): - fraction_x: float - fraction_y: float - - -def validate_fit_data(dict): - lineshape = dict.get("lineshape") - if lineshape in ["PV", "G", "L"]: - fit_data = FitDataModelPVGL(**dict) - elif lineshape == "V": - fit_data = FitDataModelVoigt(**dict) - 
else: - fit_data = FitDataModelPVPV(**dict) - - return fit_data.model_dump() - - -def validate_fit_dataframe(df): - validated_fit_data = [] - for _, row in df.iterrows(): - fit_data = validate_fit_data(row.to_dict()) - validated_fit_data.append(fit_data) - return pd.DataFrame(validated_fit_data) - - @app.command(help="Interactive plots for checking fits") def check( fits: Path, @@ -1295,7 +553,6 @@ def check( colors: Tuple[str, str] = ("#5e3c99", "#e66101"), verb: bool = False, plotly: bool = False, - config_path: Path = Path("peakipy.config"), ): """Interactive plots for checking fits @@ -1348,6 +605,7 @@ def check( fits = validate_fit_dataframe(pd.read_csv(fits)) args = {} # get dims from config file + config_path = data_path.parent / "peakipy.config" args, config = update_args_with_values_from_config_file(args, config_path) dims = config.get("dims", (1, 2, 3)) diff --git a/peakipy/constants.py b/peakipy/constants.py new file mode 100644 index 00000000..04039817 --- /dev/null +++ b/peakipy/constants.py @@ -0,0 +1,6 @@ +from numpy import log, pi, finfo + + +log2 = log(2) +π = pi +tiny = finfo(float).eps diff --git a/peakipy/fitting.py b/peakipy/fitting.py new file mode 100644 index 00000000..915ebfae --- /dev/null +++ b/peakipy/fitting.py @@ -0,0 +1,645 @@ +from dataclasses import dataclass, field +from typing import List + +import numpy as np +from numpy import sqrt +import pandas as pd +from lmfit import Model +from pydantic import BaseModel + +from peakipy.lineshapes import Lineshape, pvoigt2d, pv_pv, pv_g, pv_l, voigt2d, gaussian_lorentzian +from peakipy.constants import log2 + + +class FitDataModel(BaseModel): + plane: int + clustid: int + assignment: str + memcnt: int + amp: float + height: float + center_x_ppm: float + center_y_ppm: float + fwhm_x_hz: float + fwhm_y_hz: float + lineshape: str + x_radius: float + y_radius: float + center_x: float + center_y: float + sigma_x: float + sigma_y: float + + +class FitDataModelPVGL(FitDataModel): + fraction: float + + +class FitDataModelVoigt(FitDataModel): + fraction: float + gamma_x: float + gamma_y: float + + +class FitDataModelPVPV(FitDataModel): + fraction_x: float + fraction_y: float + + +def validate_fit_data(dict): + lineshape = dict.get("lineshape") + if lineshape in ["PV", "G", "L"]: + fit_data = FitDataModelPVGL(**dict) + elif lineshape == "V": + fit_data = FitDataModelVoigt(**dict) + else: + fit_data = FitDataModelPVPV(**dict) + + return fit_data.model_dump() + + +def validate_fit_dataframe(df): + validated_fit_data = [] + for _, row in df.iterrows(): + fit_data = validate_fit_data(row.to_dict()) + validated_fit_data.append(fit_data) + return pd.DataFrame(validated_fit_data) + + +def make_mask(data, c_x, c_y, r_x, r_y): + """Create an elliptical mask + + Generate an elliptical boolean mask with center c_x/c_y in points + with radii r_x and r_y.
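+ Points where ((x - c_x) / r_x)**2 + ((y - c_y) / r_y)**2 <= 1 are True.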
Used to generate fit mask + + :param data: 2D array + :type data: np.array + + :param c_x: x center + :type c_x: float + + :param c_y: y center + :type c_y: float + + :param r_x: radius in x + :type r_x: float + + :param r_y: radius in y + :type r_y: float + + :return: boolean mask of data.shape + :rtype: numpy.array + + """ + a, b = c_y, c_x + n_y, n_x = data.shape + y, x = np.ogrid[-a : n_y - a, -b : n_x - b] + mask = x**2.0 / r_x**2.0 + y**2.0 / r_y**2.0 <= 1.0 + return mask + + +def fix_params(params, to_fix): + """Set parameters to fix + + + :param params: lmfit parameters + :type params: lmfit.Parameters + + :param to_fix: list of parameter names to fix + :type to_fix: list + + :return: updated parameter object + :rtype: lmfit.Parameters + + """ + for k in params: + for p in to_fix: + if p in k: + params[k].vary = False + + return params + + +def get_params(params, name): + ps = [] + ps_err = [] + names = [] + prefixes = [] + for k in params: + if name in k: + ps.append(params[k].value) + ps_err.append(params[k].stderr) + names.append(k) + prefixes.append(k.split(name)[0]) + return ps, ps_err, names, prefixes + + +@dataclass +class PeakLimits: + """Given a peak position and linewidth in points, determine + the limits based on the data + + Arguments + --------- + peak: pd.DataFrame + peak is a row from a pandas dataframe + data: np.array + 2D numpy array + """ + + peak: pd.DataFrame + data: np.array + min_x: int = field(init=False) + max_x: int = field(init=False) + min_y: int = field(init=False) + max_y: int = field(init=False) + + def __post_init__(self): + assert self.peak.Y_AXIS <= self.data.shape[0] + assert self.peak.X_AXIS <= self.data.shape[1] + self.max_y = int(np.ceil(self.peak.Y_AXIS + self.peak.YW)) + 1 + if self.max_y > self.data.shape[0]: + self.max_y = self.data.shape[0] + self.max_x = int(np.ceil(self.peak.X_AXIS + self.peak.XW)) + 1 + if self.max_x > self.data.shape[1]: + self.max_x = self.data.shape[1] + + self.min_y = int(self.peak.Y_AXIS - self.peak.YW) + if self.min_y < 0: + self.min_y = 0 + self.min_x = int(self.peak.X_AXIS - self.peak.XW) + if self.min_x < 0: + self.min_x = 0 + + +def estimate_amplitude(peak, data): + assert len(data.shape) == 2 + limits = PeakLimits(peak, data) + amplitude_est = data[limits.min_y : limits.max_y, limits.min_x : limits.max_x].sum() + return amplitude_est + + +def make_param_dict(peaks, data, lineshape: Lineshape = Lineshape.PV): + """Make dict of parameter names using prefix""" + + param_dict = {} + + for _, peak in peaks.iterrows(): + str_form = lambda x: "%s%s" % (to_prefix(peak.ASS), x) + # using exact value of points (i.e. decimal) + param_dict[str_form("center_x")] = peak.X_AXISf + param_dict[str_form("center_y")] = peak.Y_AXISf + # estimate peak volume + amplitude_est = estimate_amplitude(peak, data) + param_dict[str_form("amplitude")] = amplitude_est + # sigma linewidth estimate + param_dict[str_form("sigma_x")] = peak.XW / 2.0 + param_dict[str_form("sigma_y")] = peak.YW / 2.0 + + match lineshape: + case lineshape.V: + #  Voigt G sigma from linewidth estimate + param_dict[str_form("sigma_x")] = peak.XW / ( + 2.0 * sqrt(2.0 * log2) + ) # 3.6013 + param_dict[str_form("sigma_y")] = peak.YW / ( + 2.0 * sqrt(2.0 * log2) + ) # 3.6013 + #  Voigt L gamma from linewidth estimate + param_dict[str_form("gamma_x")] = peak.XW / 2.0 + param_dict[str_form("gamma_y")] = peak.YW / 2.0 + # height + # add height here + + case lineshape.G: + param_dict[str_form("fraction")] = 0.0 + case lineshape.L: + param_dict[str_form("fraction")] = 1.0 + case
lineshape.PV_PV: + param_dict[str_form("fraction_x")] = 0.5 + param_dict[str_form("fraction_y")] = 0.5 + case _: + param_dict[str_form("fraction")] = 0.5 + + return param_dict + + +def to_prefix(x): + """ + Peak assignments with characters that are not compatible with lmfit model naming + are converted to lmfit "safe" names. + + :param x: Peak assignment to be used as prefix for lmfit model + :type x: str + + :returns: lmfit model prefix (_Peak_assignment_) + :rtype: str + + """ + # must be string + if type(x) != str: + x = str(x) + + prefix = "_" + x + to_replace = [ + [".", "_"], + [" ", ""], + ["{", "_"], + ["}", "_"], + ["[", "_"], + ["]", "_"], + ["-", ""], + ["/", "or"], + ["?", "maybe"], + ["\\", ""], + ["(", "_"], + [")", "_"], + ["@", "_at_"], + ] + for p in to_replace: + prefix = prefix.replace(*p) + return prefix + "_" + + +def make_models( + model, + peaks, + data, + lineshape: Lineshape = Lineshape.PV, + xy_bounds=None, +): + """Make composite models for multiple peaks + + :param model: lineshape function + :type model: function + + :param peaks: instance of pandas.df.groupby("CLUSTID") + :type peaks: pandas.df.groupby("CLUSTID") + + :param data: NMR data + :type data: numpy.array + + :param lineshape: lineshape to use for fit (PV/G/L/PV_PV) + :type lineshape: str + + :param xy_bounds: bounds for peak centers (+/-x, +/-y) + :type xy_bounds: tuple + + :return mod: Composite lmfit model containing all peaks + :rtype mod: lmfit.CompositeModel + + :return p_guess: params for composite model with starting values + :rtype p_guess: lmfit.Parameters + + """ + if len(peaks) == 1: + # make model for first peak + mod = Model(model, prefix="%s" % to_prefix(peaks.ASS.iloc[0])) + # add parameters + param_dict = make_param_dict( + peaks, + data, + lineshape=lineshape, + ) + p_guess = mod.make_params(**param_dict) + + elif len(peaks) > 1: + # make model for first peak + first_peak, *remaining_peaks = peaks.iterrows() + mod = Model(model, prefix="%s" % to_prefix(first_peak[1].ASS)) + for _, peak in remaining_peaks: + mod += Model(model, prefix="%s" % to_prefix(peak.ASS)) + + param_dict = make_param_dict( + peaks, + data, + lineshape=lineshape, + ) + p_guess = mod.make_params(**param_dict) + # add Peak params to p_guess + + update_params(p_guess, param_dict, lineshape=lineshape, xy_bounds=xy_bounds) + + return mod, p_guess + + +def update_params( + params, param_dict, lineshape: Lineshape = Lineshape.PV, xy_bounds=None +): + """Update lmfit parameters with values from Peak + + :param params: lmfit parameters + :type params: lmfit.Parameters object + :param param_dict: parameters corresponding to each peak in fit + :type param_dict: dict + :param lineshape: Lineshape (PV, G, L, PV_PV etc.)
+ :type lineshape: Lineshape + :param xy_bounds: bounds on xy peak positions + :type xy_bounds: tuple + + :returns: None + :rtype: None + + ToDo + -- deal with boundaries + -- currently positions in points + + """ + for k, v in param_dict.items(): + params[k].value = v + # print("update", k, v) + if "center" in k: + if xy_bounds == None: + # no bounds set + pass + else: + if "center_x" in k: + # set x bounds + x_bound = xy_bounds[0] + params[k].min = v - x_bound + params[k].max = v + x_bound + elif "center_y" in k: + # set y bounds + y_bound = xy_bounds[1] + params[k].min = v - y_bound + params[k].max = v + y_bound + # pass + # print( + # "setting limit of %s, min = %.3e, max = %.3e" + # % (k, params[k].min, params[k].max) + # ) + elif "sigma" in k: + params[k].min = 0.0 + params[k].max = 1e4 + + elif "gamma" in k: + params[k].min = 0.0 + params[k].max = 1e4 + # print( + # "setting limit of %s, min = %.3e, max = %.3e" + # % (k, params[k].min, params[k].max) + # ) + elif "fraction" in k: + # fix weighting between 0 and 1 + params[k].min = 0.0 + params[k].max = 1.0 + + #  fix fraction of G or L + match lineshape: + case lineshape.G | lineshape.L: + params[k].vary = False + case lineshape.PV | lineshape.PV_PV: + params[k].vary = True + case _: + pass + + # return params + + +def make_mask_from_peak_cluster(group, data): + mask = np.zeros(data.shape, dtype=bool) + for _, peak in group.iterrows(): + mask += make_mask( + data, peak.X_AXISf, peak.Y_AXISf, peak.X_RADIUS, peak.Y_RADIUS + ) + return mask, peak + + +def select_reference_planes_using_indices(data, indices: List[int]): + n_planes = data.shape[0] + if indices == []: + return data + + max_index = max(indices) + min_index = min(indices) + + if max_index >= n_planes: + raise IndexError( + f"Your data has {n_planes} planes. You selected plane {max_index} (allowed indices between 0 and {n_planes-1})" + ) + elif min_index < (-1 * n_planes): + raise IndexError( + f"Your data has {n_planes} planes. You selected plane {min_index} (allowed indices between -{n_planes} and {n_planes-1})" + ) + else: + data = data[indices] + return data + + +def select_planes_above_threshold_from_masked_data(data, threshold=None): + """This function returns planes with data above the threshold. + + It currently uses absolute intensity values. + Negative thresholds just result in return of the original data. + + """ + if threshold == None: + selected_data = data + else: + selected_data = data[np.abs(data).max(axis=1) > threshold] + + if selected_data.shape[0] == 0: + selected_data = data + + return selected_data + + +def validate_plane_selection(plane, pseudo3D): + if (plane == []) or (plane == None): + plane = list(range(pseudo3D.n_planes)) + + elif max(plane) > (pseudo3D.n_planes - 1): + raise ValueError( + f"[red]There are {pseudo3D.n_planes} planes in your data; you selected --plane {max(plane)}...[/red]" + f"plane numbering starts from 0."
+ ) + elif min(plane) < 0: + raise ValueError( + f"[red]Plane number can not be negative; you selected --plane {min(plane)}...[/red]" + ) + else: + plane = sorted(plane) + + return plane + + +def slice_peaks_from_data_using_mask(data, mask): + peak_slices = np.array([d[mask] for d in data]) + return peak_slices + + +def get_limits_for_axis_in_points(group_axis_points, mask_radius_in_points): + max_point, min_point = ( + int(np.ceil(max(group_axis_points) + mask_radius_in_points + 1)), + int(np.floor(min(group_axis_points) - mask_radius_in_points)), + ) + return max_point, min_point + + +def deal_with_peaks_on_edge_of_spectrum(data_shape, max_x, min_x, max_y, min_y): + if min_y < 0: + min_y = 0 + + if min_x < 0: + min_x = 0 + + if max_y > data_shape[-2]: + max_y = data_shape[-2] + + if max_x > data_shape[-1]: + max_x = data_shape[-1] + return max_x, min_x, max_y, min_y + + +def make_meshgrid(data_shape): + # must be a better way to make the meshgrid + x = np.arange(data_shape[-1]) + y = np.arange(data_shape[-2]) + XY = np.meshgrid(x, y) + return XY + + +def unpack_xy_bounds(xy_bounds, peakipy_data): + match xy_bounds: + case (0, 0): + xy_bounds = None + case (x, y): + # convert ppm to points + xy_bounds = list(xy_bounds) + xy_bounds[0] = xy_bounds[0] * peakipy_data.pt_per_ppm_f2 + xy_bounds[1] = xy_bounds[1] * peakipy_data.pt_per_ppm_f1 + case _: + raise TypeError( + "xy_bounds should be a tuple (x_bound, y_bound)" + ) + return xy_bounds + + +def select_specified_planes(plane, peakipy_data): + plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]]) + # only fit specified planes + if plane: + inds = [i for i in plane] + data_inds = [ + (i in inds) for i in range(peakipy_data.data.shape[peakipy_data.dims[0]]) + ] + plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]])[ + data_inds + ] + peakipy_data.data = peakipy_data.data[data_inds] + print( + f"[yellow]Using only planes {plane} data now has the following shape[/yellow]", + peakipy_data.data.shape, + ) + if peakipy_data.data.shape[peakipy_data.dims[0]] == 0: + print("[red]You have excluded all the data![/red]", peakipy_data.data.shape) + exit() + return plane_numbers, peakipy_data + + +def exclude_specified_planes(exclude_plane, peakipy_data): + plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]]) + # do not fit these planes + if exclude_plane: + inds = [i for i in exclude_plane] + data_inds = [ + (i not in inds) + for i in range(peakipy_data.data.shape[peakipy_data.dims[0]]) + ] + plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]])[ + data_inds + ] + peakipy_data.data = peakipy_data.data[data_inds] + print( + f"[yellow]Excluding planes {exclude_plane} data now has the following shape[/yellow]", + peakipy_data.data.shape, + ) + if peakipy_data.data.shape[peakipy_data.dims[0]] == 0: + print("[red]You have excluded all the data![/red]", peakipy_data.data.shape) + exit() + return plane_numbers, peakipy_data + + +def get_fit_data_for_selected_peak_clusters(fits, clusters): + match clusters: + case None | []: + pass + case _: + # only use these clusters + fits = fits[fits.clustid.isin(clusters)] + if len(fits) < 1: + exit(f"Are you sure clusters {clusters} exist?") + return fits + + +def make_masks_from_plane_data(empty_mask_array, plane_data): + # make masks + individual_masks = [] + for cx, cy, rx, ry, name in zip( + plane_data.center_x, + plane_data.center_y, + plane_data.x_radius, + plane_data.y_radius, + plane_data.assignment, + ): + tmp_mask = make_mask(empty_mask_array, cx, cy,
rx, ry) + empty_mask_array += tmp_mask + individual_masks.append(tmp_mask) + filled_mask_array = empty_mask_array + return individual_masks, filled_mask_array + + +def simulate_pv_pv_lineshapes_from_fitted_peak_parameters( + peak_parameters, XY, sim_data, sim_data_singles +): + for amp, c_x, c_y, s_x, s_y, frac_x, frac_y, ls in zip( + peak_parameters.amp, + peak_parameters.center_x, + peak_parameters.center_y, + peak_parameters.sigma_x, + peak_parameters.sigma_y, + peak_parameters.fraction_x, + peak_parameters.fraction_y, + peak_parameters.lineshape, + ): + sim_data_i = pv_pv(XY, amp, c_x, c_y, s_x, s_y, frac_x, frac_y).reshape( + sim_data.shape + ) + sim_data += sim_data_i + sim_data_singles.append(sim_data_i) + return sim_data, sim_data_singles + + +def simulate_lineshapes_from_fitted_peak_parameters( + peak_parameters, XY, sim_data, sim_data_singles +): + shape = sim_data.shape + for amp, c_x, c_y, s_x, s_y, frac, lineshape in zip( + peak_parameters.amp, + peak_parameters.center_x, + peak_parameters.center_y, + peak_parameters.sigma_x, + peak_parameters.sigma_y, + peak_parameters.fraction, + peak_parameters.lineshape, + ): + # print(amp) + match lineshape: + case "G" | "L" | "PV": + sim_data_i = pvoigt2d(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) + case "PV_L": + sim_data_i = pv_l(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) + + case "PV_G": + sim_data_i = pv_g(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) + + case "G_L": + sim_data_i = gaussian_lorentzian( + XY, amp, c_x, c_y, s_x, s_y, frac + ).reshape(shape) + + case "V": + sim_data_i = voigt2d(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) + sim_data += sim_data_i + sim_data_singles.append(sim_data_i) + return sim_data, sim_data_singles diff --git a/peakipy/io.py b/peakipy/io.py new file mode 100644 index 00000000..88ed5be8 --- /dev/null +++ b/peakipy/io.py @@ -0,0 +1,913 @@ +import sys +from pathlib import Path +from enum import Enum + +import numpy as np +import nmrglue as ng +import pandas as pd +import textwrap +from rich import print +from rich.console import Console + + +from bokeh.palettes import Category20 +from scipy import ndimage +from skimage.morphology import square, binary_closing, disk, rectangle +from skimage.filters import threshold_otsu + +from peakipy.utils import df_to_rich_table +from peakipy.fitting import make_mask + +console = Console() + + +class StrucEl(str, Enum): + square = "square" + disk = "disk" + rectangle = "rectangle" + mask_method = "mask_method" + + +class PeaklistFormat(str, Enum): + a2 = "a2" + a3 = "a3" + sparky = "sparky" + pipe = "pipe" + peakipy = "peakipy" + + +class OutFmt(str, Enum): + csv = "csv" + pkl = "pkl" + + +class Pseudo3D: + """Read dic, data from NMRGlue and dims from input to create a Pseudo3D dataset + + :param dic: from nmrglue.pipe.read + :type dic: dict + + :param data: data from nmrglue.pipe.read + :type data: numpy.array + + :param dims: dimension order i.e [0,1,2] where 0 = planes, 1 = f1, 2 = f2 + :type dims: list + """ + + def __init__(self, dic, data, dims): + # check dimensions + self._udic = ng.pipe.guess_udic(dic, data) + self._ndim = self._udic["ndim"] + + if self._ndim == 1: + err = f"""[red] + ########################################## + NMR Data should be either 2D or 3D + ########################################## + [/red]""" + # raise TypeError(err) + sys.exit(err) + + # check that spectrum has correct number of dims + elif self._ndim != len(dims): + err = f"""[red] + ################################################################# + Your 
spectrum has {self._ndim} dimensions with shape {data.shape} + but you have given a dimension order of {dims}... + ################################################################# + [/red]""" + # raise ValueError(err) + sys.exit(err) + + elif (self._ndim == 2) and (len(dims) == 2): + self._f1_dim, self._f2_dim = dims + self._planes = 0 + self._uc_f1 = ng.pipe.make_uc(dic, data, dim=self._f1_dim) + self._uc_f2 = ng.pipe.make_uc(dic, data, dim=self._f2_dim) + # make data pseudo3d + self._data = data.reshape((1, data.shape[0], data.shape[1])) + self._dims = [self._planes, self._f1_dim + 1, self._f2_dim + 1] + + else: + self._planes, self._f1_dim, self._f2_dim = dims + self._dims = dims + self._data = data + # make unit conversion dicts + self._uc_f2 = ng.pipe.make_uc(dic, data, dim=self._f2_dim) + self._uc_f1 = ng.pipe.make_uc(dic, data, dim=self._f1_dim) + + #  rearrange data if dims not in standard order + if self._dims != [0, 1, 2]: + # np.argsort returns indices of array for order 0,1,2 to transpose data correctly + # self._dims = np.argsort(self._dims) + self._data = np.transpose(data, self._dims) + + self._dic = dic + + self._f1_label = self._udic[self._f1_dim]["label"] + self._f2_label = self._udic[self._f2_dim]["label"] + + @property + def uc_f1(self): + """Return unit conversion dict for F1""" + return self._uc_f1 + + @property + def uc_f2(self): + """Return unit conversion dict for F2""" + return self._uc_f2 + + @property + def dims(self): + """Return dimension order""" + return self._dims + + @property + def data(self): + """Return array containing data""" + return self._data + + @data.setter + def data(self, data): + self._data = data + + @property + def dic(self): + return self._dic + + @property + def udic(self): + return self._udic + + @property + def ndim(self): + return self._ndim + + @property + def f1_label(self): + # dim label + return self._f1_label + + @property + def f2_label(self): + # dim label + return self._f2_label + + @property + def planes(self): + return self.dims[0] + + @property + def n_planes(self): + return self.data.shape[self.planes] + + @property + def f1(self): + return self.dims[1] + + @property + def f2(self): + return self.dims[2] + + # size of f1 and f2 in points + @property + def f2_size(self): + """Return size of f2 dimension in points""" + return self._udic[self._f2_dim]["size"] + + @property + def f1_size(self): + """Return size of f1 dimension in points""" + return self._udic[self._f1_dim]["size"] + + # points per ppm + @property + def pt_per_ppm_f1(self): + return self.f1_size / ( + self._udic[self._f1_dim]["sw"] / self._udic[self._f1_dim]["obs"] + ) + + @property + def pt_per_ppm_f2(self): + return self.f2_size / ( + self._udic[self._f2_dim]["sw"] / self._udic[self._f2_dim]["obs"] + ) + + # points per hz + @property + def pt_per_hz_f1(self): + return self.f1_size / self._udic[self._f1_dim]["sw"] + + @property + def pt_per_hz_f2(self): + return self.f2_size / self._udic[self._f2_dim]["sw"] + + # hz per point + @property + def hz_per_pt_f1(self): + return 1.0 / self.pt_per_hz_f1 + + @property + def hz_per_pt_f2(self): + return 1.0 / self.pt_per_hz_f2 + + # ppm per point + @property + def ppm_per_pt_f1(self): + return 1.0 / self.pt_per_ppm_f1 + + @property + def ppm_per_pt_f2(self): + return 1.0 / self.pt_per_ppm_f2 + + # get ppm limits for ppm scales + @property + def f2_ppm_scale(self): + return self.uc_f2.ppm_scale() + + @property + def f1_ppm_scale(self): + return self.uc_f1.ppm_scale() + + @property + def f2_ppm_limits(self): + return 
self.uc_f2.ppm_limits() + + @property + def f1_ppm_limits(self): + return self.uc_f1.ppm_limits() + + @property + def f1_ppm_max(self): + return max(self.f1_ppm_limits) + + @property + def f1_ppm_min(self): + return min(self.f1_ppm_limits) + + @property + def f2_ppm_max(self): + return max(self.f2_ppm_limits) + + @property + def f2_ppm_min(self): + return min(self.f2_ppm_limits) + + @property + def f2_ppm_0(self): + return self.f2_ppm_limits[0] + + @property + def f2_ppm_1(self): + return self.f2_ppm_limits[1] + + @property + def f1_ppm_0(self): + return self.f1_ppm_limits[0] + + @property + def f1_ppm_1(self): + return self.f1_ppm_limits[1] + + +class UnknownFormat(Exception): + pass + + +class Peaklist(Pseudo3D): + """Read analysis, sparky or NMRPipe peak list and convert to NMRPipe-ish format; also find peak clusters + + Parameters + ---------- + path : path-like or str + path to peaklist + data_path : path-like or str + path to NMRPipe format data + fmt : str + a2|a3|sparky|pipe + dims: list + [planes,y,x] + radii: list + [x,y] Mask radii in ppm + + + Methods + ------- + + clusters : + mask_method : + adaptive_clusters : + + Returns + ------- + df : pandas DataFrame + dataframe containing peaklist + + """ + + def __init__( + self, + path, + data_path, + fmt: PeaklistFormat = PeaklistFormat.a2, + dims=[0, 1, 2], + radii=[0.04, 0.4], + posF1="Position F2", + posF2="Position F1", + verbose=False, + ): + dic, data = ng.pipe.read(data_path) + Pseudo3D.__init__(self, dic, data, dims) + self.fmt = fmt + self.peaklist_path = path + self.data_path = data_path + self.verbose = verbose + self._radii = radii + self._thres = None + if self.verbose: + print( + "Points per hz f1 = %.3f, f2 = %.3f" + % (self.pt_per_hz_f1, self.pt_per_hz_f2) + ) + + self._analysis_to_pipe_dic = { + "#": "INDEX", + "Position F1": "X_PPM", + "Position F2": "Y_PPM", + "Line Width F1 (Hz)": "XW_HZ", + "Line Width F2 (Hz)": "YW_HZ", + "Height": "HEIGHT", + "Volume": "VOL", + } + self._assign_to_pipe_dic = { + "#": "INDEX", + "Pos F1": "X_PPM", + "Pos F2": "Y_PPM", + "LW F1 (Hz)": "XW_HZ", + "LW F2 (Hz)": "YW_HZ", + "Height": "HEIGHT", + "Volume": "VOL", + } + + self._sparky_to_pipe_dic = { + "index": "INDEX", + "w1": "X_PPM", + "w2": "Y_PPM", + "lw1 (hz)": "XW_HZ", + "lw2 (hz)": "YW_HZ", + "Height": "HEIGHT", + "Volume": "VOL", + "Assignment": "ASS", + } + + self._analysis_to_pipe_dic[posF1] = "Y_PPM" + self._analysis_to_pipe_dic[posF2] = "X_PPM" + + self._df = self.read_peaklist() + + def read_peaklist(self): + match self.fmt: + case self.fmt.a2: + self._df = self._read_analysis() + + case self.fmt.a3: + self._df = self._read_assign() + + case self.fmt.sparky: + self._df = self._read_sparky() + + case self.fmt.pipe: + self._df = self._read_pipe() + + case _: + raise UnknownFormat(f"I don't know this format: {self.fmt}") + + return self._df + + @property + def df(self): + return self._df + + @df.setter + def df(self, df): + self._df = df + return self._df + + @property + def radii(self): + return self._radii + + @property + def f2_radius(self): + """radius for fitting mask in f2""" + return self.radii[0] + + @property + def f1_radius(self): + """radius for fitting mask in f1""" + return self.radii[1] + + @property + def analysis_to_pipe_dic(self): + return self._analysis_to_pipe_dic + + @property + def assign_to_pipe_dic(self): + return self._assign_to_pipe_dic + + @property + def sparky_to_pipe_dic(self): + return self._sparky_to_pipe_dic + + @property + def thres(self): + if self._thres == None: + self._thres =
abs(threshold_otsu(self.data[0])) + return self._thres + else: + return self._thres + + def update_df(self): + # int point value + self.df["X_AXIS"] = self.df.X_PPM.apply(lambda x: self.uc_f2(x, "ppm")) + self.df["Y_AXIS"] = self.df.Y_PPM.apply(lambda x: self.uc_f1(x, "ppm")) + # decimal point value + self.df["X_AXISf"] = self.df.X_PPM.apply(lambda x: self.uc_f2.f(x, "ppm")) + self.df["Y_AXISf"] = self.df.Y_PPM.apply(lambda x: self.uc_f1.f(x, "ppm")) + # in case of missing values (should estimate though) + self.df["XW_HZ"] = self.df.XW_HZ.replace("None", "20.0") + self.df["YW_HZ"] = self.df.YW_HZ.replace("None", "20.0") + self.df["XW_HZ"] = self.df.XW_HZ.replace(np.NaN, "20.0") + self.df["YW_HZ"] = self.df.YW_HZ.replace(np.NaN, "20.0") + # convert linewidths to float + self.df["XW_HZ"] = self.df.XW_HZ.apply(lambda x: float(x)) + self.df["YW_HZ"] = self.df.YW_HZ.apply(lambda x: float(x)) + # convert Hz lw to points + self.df["XW"] = self.df.XW_HZ.apply(lambda x: x * self.pt_per_hz_f2) + self.df["YW"] = self.df.YW_HZ.apply(lambda x: x * self.pt_per_hz_f1) + # makes an assignment column from Assign F1 and Assign F2 columns + # in analysis2.x and ccpnmr v3 assign peak lists + if self.fmt in [PeaklistFormat.a2, PeaklistFormat.a3]: + self.df["ASS"] = self.df.apply( + # lambda i: "".join([i["Assign F1"], i["Assign F2"]]), axis=1 + lambda i: f"{i['Assign F1']}_{i['Assign F2']}", + axis=1, + ) + + # make default values for X and Y radii for fit masks + self.df["X_RADIUS_PPM"] = np.zeros(len(self.df)) + self.f2_radius + self.df["Y_RADIUS_PPM"] = np.zeros(len(self.df)) + self.f1_radius + self.df["X_RADIUS"] = self.df.X_RADIUS_PPM.apply( + lambda x: x * self.pt_per_ppm_f2 + ) + self.df["Y_RADIUS"] = self.df.Y_RADIUS_PPM.apply( + lambda x: x * self.pt_per_ppm_f1 + ) + # add include column + if "include" in self.df.columns: + pass + else: + self.df["include"] = self.df.apply(lambda x: "yes", axis=1) + + # check assignments for duplicates + self.check_assignments() + # check that peaks are within the bounds of the data + self.check_peak_bounds() + + def add_fix_bound_columns(self): + """add columns containing parameter bounds (param_upper/param_lower) + and whether or not parameter should be fixed (yes/no) + + For parameter bounding: + + Column names are param_upper and param_lower for upper and lower bounds respectively. + Values are given as floating point. Value of 0.0 indicates that parameter is unbounded + X/Y positions are given in ppm + Linewidths are given in Hz + + For parameter fixing: + + Column names are param_fix.
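+ e.g. XW_HZ_fix for the X linewidth.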
+    def add_fix_bound_columns(self):
+        """add columns containing parameter bounds (param_upper/param_lower)
+        and whether or not a parameter should be fixed (yes/no)
+
+        For parameter bounding:
+
+            Column names are <param_name>_upper and <param_name>_lower for the
+            upper and lower bounds, respectively.
+            Values are given as floating point. A value of 0.0 indicates that
+            the parameter is unbounded.
+            X/Y positions are given in ppm; linewidths are given in Hz.
+
+        For parameter fixing:
+
+            Column names are <param_name>_fix.
+            Values are given as a string: 'yes' or 'no'.
+
+        """
+        pass
+
+    def _read_analysis(self):
+        df = pd.read_csv(self.peaklist_path, delimiter="\t")
+        new_columns = [self.analysis_to_pipe_dic.get(i, i) for i in df.columns]
+        pipe_columns = dict(zip(df.columns, new_columns))
+        df = df.rename(index=str, columns=pipe_columns)
+
+        return df
+
+    def _read_assign(self):
+        df = pd.read_csv(self.peaklist_path, delimiter="\t")
+        new_columns = [self.assign_to_pipe_dic.get(i, i) for i in df.columns]
+        pipe_columns = dict(zip(df.columns, new_columns))
+        df = df.rename(index=str, columns=pipe_columns)
+
+        return df
+
+    def _read_sparky(self):
+        df = pd.read_csv(
+            self.peaklist_path,
+            skiprows=1,
+            sep=r"\s+",
+            names=["ASS", "Y_PPM", "X_PPM", "VOLUME", "HEIGHT", "YW_HZ", "XW_HZ"],
+        )
+        df["INDEX"] = df.index
+
+        return df
+
+    def _read_pipe(self):
+        to_skip = 0
+        with open(self.peaklist_path) as f:
+            lines = f.readlines()
+            for line in lines:
+                if line.startswith("VARS"):
+                    columns = line.strip().split()[1:]
+                elif line[:5].strip(" ").isdigit():
+                    break
+                else:
+                    to_skip += 1
+        df = pd.read_csv(
+            self.peaklist_path, skiprows=to_skip, names=columns, sep=r"\s+"
+        )
+        return df
+
+    def check_assignments(self):
+        self.df["ASS"] = self.df.ASS.astype(object)
+        self.df.loc[self.df["ASS"].isnull(), "ASS"] = "None_dummy_0"
+        self.df["ASS"] = self.df.ASS.astype(str)
+        duplicates_bool = self.df.ASS.duplicated()
+        duplicates = self.df.ASS[duplicates_bool]
+        if len(duplicates) > 0:
+            console.print(
+                textwrap.dedent(
+                    """
+                    #############################################################################
+                    You have duplicated assignments in your list...
+                    Currently each peak needs a unique assignment. Sorry about that buddy...
+                    #############################################################################
+                    """
+                ),
+                style="yellow",
+            )
+            self.df.loc[duplicates_bool, "ASS"] = [
+                f"{i}_dummy_{num+1}" for num, i in enumerate(duplicates)
+            ]
+            if self.verbose:
+                print("Here are the duplicates")
+                print(duplicates)
+                print(self.df.ASS)
+
+            print(
+                textwrap.dedent(
+                    """
+                    Creating dummy assignments for duplicates
+
+                    """
+                )
+            )
+
+    def check_peak_bounds(self):
+        columns_to_print = ["INDEX", "ASS", "X_AXIS", "Y_AXIS", "X_PPM", "Y_PPM"]
+        # check that peaks are within the bounds of the spectrum
+        within_x = (self.df.X_PPM < self.f2_ppm_max) & (self.df.X_PPM > self.f2_ppm_min)
+        within_y = (self.df.Y_PPM < self.f1_ppm_max) & (self.df.Y_PPM > self.f1_ppm_min)
+        self.excluded = self.df[~(within_x & within_y)]
+        self.df = self.df[within_x & within_y]
+        if len(self.excluded) > 0:
+            print(
+                textwrap.dedent(
+                    f"""[red]
+                    #################################################################################
+
+                    Excluding the following peaks as they are not within the spectrum which has shape
+
+                    {self.data.shape}
+                    [/red]"""
+                )
+            )
+            table_to_print = df_to_rich_table(
+                self.excluded,
+                title="Excluded",
+                columns=columns_to_print,
+                styles=["red" for i in columns_to_print],
+            )
+            print(table_to_print)
+            print(
+                "[red]#################################################################################[/red]"
+            )
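Before moving on to the clustering methods, a usage sketch of the reader layer; the peak list and spectrum paths below are hypothetical:

```python
from peakipy.io import Peaklist, PeaklistFormat

peaklist = Peaklist(
    "peaks.sparky",            # hypothetical Sparky peak list
    "test.ft2",                # hypothetical NMRPipe spectrum
    fmt=PeaklistFormat.sparky,
    dims=[0, 1, 2],
    radii=[0.04, 0.4],         # f2/f1 mask radii in ppm
)
peaklist.update_df()  # ppm -> points, fill missing linewidths, run the checks
print(peaklist.df[["ASS", "X_PPM", "Y_PPM", "XW_HZ", "YW_HZ"]].head())
```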
+    def clusters(
+        self,
+        thres=None,
+        struc_el: StrucEl = StrucEl.disk,
+        struc_size=(3,),
+        l_struc=None,
+    ):
+        """Find clusters of peaks
+
+        :param thres: threshold for positive signals above which clusters are selected.
+            If None then threshold_otsu is used
+        :type thres: float
+
+        :param struc_el: 'square'|'disk'|'rectangle'
+            structuring element for binary_closing of thresholded data
+        :type struc_el: StrucEl
+
+        :param struc_size: size/dimensions of structuring element;
+            for square and disk the first element of the tuple is used
+            (for disk the value corresponds to the radius),
+            for rectangle the tuple corresponds to (width, height)
+        :type struc_size: tuple
+
+        """
+        peaks = [[y, x] for y, x in zip(self.df.Y_AXIS, self.df.X_AXIS)]
+
+        if thres is None:
+            thres = self.thres
+            self._thres = abs(threshold_otsu(self.data[0]))
+        else:
+            self._thres = thres
+
+        # keep both positive and negative signals beyond the threshold
+        thresh_data = np.bitwise_or(
+            self.data[0] < (self._thres * -1.0), self.data[0] > self._thres
+        )
+
+        match struc_el:
+            case struc_el.disk:
+                radius = struc_size[0]
+                if self.verbose:
+                    print(f"using disk with {radius}")
+                closed_data = binary_closing(thresh_data, disk(int(radius)))
+
+            case struc_el.square:
+                width = struc_size[0]
+                if self.verbose:
+                    print(f"using square with {width}")
+                closed_data = binary_closing(thresh_data, square(int(width)))
+
+            case struc_el.rectangle:
+                width, height = struc_size
+                if self.verbose:
+                    print(f"using rectangle with {width} and {height}")
+                closed_data = binary_closing(
+                    thresh_data, rectangle(int(width), int(height))
+                )
+
+            case _:
+                if self.verbose:
+                    print("Not using any closing function")
+                closed_data = thresh_data
+
+        labeled_array, num_features = ndimage.label(closed_data, l_struc)
+
+        self.df.loc[:, "CLUSTID"] = [labeled_array[i[0], i[1]] for i in peaks]
+
+        # renumber "0" clusters
+        max_clustid = self.df["CLUSTID"].max()
+        n_of_zeros = len(self.df[self.df["CLUSTID"] == 0]["CLUSTID"])
+        self.df.loc[self.df[self.df["CLUSTID"] == 0].index, "CLUSTID"] = np.arange(
+            max_clustid + 1, n_of_zeros + max_clustid + 1, dtype=int
+        )
+
+        # count how many peaks are in each cluster
+        for ind, group in self.df.groupby("CLUSTID"):
+            self.df.loc[group.index, "MEMCNT"] = len(group)
+
+        self.df.loc[:, "color"] = self.df.apply(
+            lambda x: Category20[20][int(x.CLUSTID) % 20] if x.MEMCNT > 1 else "black",
+            axis=1,
+        )
+        return ClustersResult(labeled_array, num_features, closed_data, peaks)
+    def mask_method(self, overlap=1.0, l_struc=None):
+        """connect clusters based on overlap of fitting masks
+
+        :param overlap: fraction of mask for which overlaps are calculated
+        :type overlap: float
+
+        :returns ClustersResult: Instance of ClustersResult
+        :rtype: ClustersResult
+        """
+        # overlap is positive
+        overlap = abs(overlap)
+
+        self._thres = threshold_otsu(self.data[0])
+
+        mask = np.zeros(self.data[0].shape, dtype=bool)
+
+        for ind, peak in self.df.iterrows():
+            mask += make_mask(
+                self.data[0],
+                peak.X_AXISf,
+                peak.Y_AXISf,
+                peak.X_RADIUS * overlap,
+                peak.Y_RADIUS * overlap,
+            )
+
+        peaks = [[y, x] for y, x in zip(self.df.Y_AXIS, self.df.X_AXIS)]
+        labeled_array, num_features = ndimage.label(mask, l_struc)
+
+        self.df.loc[:, "CLUSTID"] = [labeled_array[i[0], i[1]] for i in peaks]
+
+        # renumber "0" clusters
+        max_clustid = self.df["CLUSTID"].max()
+        n_of_zeros = len(self.df[self.df["CLUSTID"] == 0]["CLUSTID"])
+        self.df.loc[self.df[self.df["CLUSTID"] == 0].index, "CLUSTID"] = np.arange(
+            max_clustid + 1, n_of_zeros + max_clustid + 1, dtype=int
+        )
+
+        # count how many peaks are in each cluster
+        for ind, group in self.df.groupby("CLUSTID"):
+            self.df.loc[group.index, "MEMCNT"] = len(group)
+
+        self.df.loc[:, "color"] = self.df.apply(
+            lambda x: Category20[20][int(x.CLUSTID) % 20] if x.MEMCNT > 1 else "black",
+            axis=1,
+        )
+
+        return ClustersResult(labeled_array, num_features, mask, peaks)
+
+    def to_fuda(self, fname="params.fuda"):
+        with open("peaks.fuda", "w") as peaks_fuda:
+            for ass, f1_ppm, f2_ppm in zip(self.df.ASS, self.df.Y_PPM, self.df.X_PPM):
+                peaks_fuda.write(f"{ass}\t{f1_ppm:.3f}\t{f2_ppm:.3f}\n")
+        groups = self.df.groupby("CLUSTID")
+        fuda_params = Path(fname)
+        overlap_peaks = ""
+
+        for ind, group in groups:
+            if len(group) > 1:
+                overlap_peaks_str = ";".join(group.ASS)
+                overlap_peaks += f"OVERLAP_PEAKS=({overlap_peaks_str})\n"
+
+        fuda_file = textwrap.dedent(
+            f"""\
+
+# Read peaklist and spectrum info
+PEAKLIST=peaks.fuda
+SPECFILE={self.data_path}
+PARAMETERFILE=(bruker;vclist)
+ZCORR=ncyc
+NOISE={self.thres} # you'll need to adjust this
+BASELINE=N
+VERBOSELEVEL=5
+PRINTDATA=Y
+LM=(MAXFEV=250;TOL=1e-5)
+#Specify the default values. All values are in ppm:
+DEF_LINEWIDTH_F1={self.f1_radius}
+DEF_LINEWIDTH_F2={self.f2_radius}
+DEF_RADIUS_F1={self.f1_radius}
+DEF_RADIUS_F2={self.f2_radius}
+SHAPE=GLORE
+# OVERLAP PEAKS
+{overlap_peaks}"""
+        )
+        with open(fuda_params, "w") as f:
+            print(f"Writing FuDA file {fuda_params}")
+            f.write(fuda_file)
+        if self.verbose:
+            print(overlap_peaks)
+
+
+class ClustersResult:
+    """Class to store the results of the clusters function"""
+
+    def __init__(self, labeled_array, num_features, closed_data, peaks):
+        self._labeled_array = labeled_array
+        self._num_features = num_features
+        self._closed_data = closed_data
+        self._peaks = peaks
+
+    @property
+    def labeled_array(self):
+        return self._labeled_array
+
+    @property
+    def num_features(self):
+        return self._num_features
+
+    @property
+    def closed_data(self):
+        return self._closed_data
+
+    @property
+    def peaks(self):
+        return self._peaks
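A sketch of the clustering step and the `ClustersResult` it returns, continuing the hypothetical example above:

```python
from peakipy.io import StrucEl

# peaklist: the Peaklist instance from the earlier sketch
clusters_result = peaklist.clusters(
    thres=None,              # fall back to the Otsu threshold
    struc_el=StrucEl.disk,
    struc_size=(3,),
)
print(clusters_result.num_features)  # number of labelled clusters
print(peaklist.df[["ASS", "CLUSTID", "MEMCNT"]].head())
```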
+class LoadData(Peaklist):
+    """Load peaklist data from a peakipy .csv file output by either peakipy read or edit
+
+    read_peaklist is redefined to just read a .csv file
+
+    check_data_frame makes sure the data frame is in good shape for setting up fits
+
+    """
+
+    def read_peaklist(self):
+        if self.peaklist_path.suffix == ".csv":
+            self.df = pd.read_csv(self.peaklist_path)  # , comment="#")
+
+        elif self.peaklist_path.suffix == ".tab":
+            self.df = pd.read_csv(self.peaklist_path, sep="\t")  # comment="#")
+
+        else:
+            self.df = pd.read_pickle(self.peaklist_path)
+
+        self._thres = threshold_otsu(self.data[0])
+
+        return self.df
+
+    def check_data_frame(self):
+        # make diameter columns
+        if "X_DIAMETER_PPM" not in self.df.columns:
+            self.df["X_DIAMETER_PPM"] = self.df["X_RADIUS_PPM"] * 2.0
+            self.df["Y_DIAMETER_PPM"] = self.df["Y_RADIUS_PPM"] * 2.0
+
+        # make a column to track edited peaks
+        if "Edited" not in self.df.columns:
+            self.df["Edited"] = np.zeros(len(self.df), dtype=bool)
+
+        # create include column if it doesn't exist
+        if "include" not in self.df.columns:
+            self.df["include"] = self.df.apply(lambda _: "yes", axis=1)
+
+        # color clusters
+        self.df["color"] = self.df.apply(
+            lambda x: Category20[20][int(x.CLUSTID) % 20] if x.MEMCNT > 1 else "black",
+            axis=1,
+        )
+
+        # get rid of unnamed columns
+        unnamed_cols = [i for i in self.df.columns if "Unnamed:" in i]
+        self.df = self.df.drop(columns=unnamed_cols)
+
+    def update_df(self):
+        """Slightly modified to retain previous configurations"""
+        # int point values
+        self.df["X_AXIS"] = self.df.X_PPM.apply(lambda x: self.uc_f2(x, "ppm"))
+        self.df["Y_AXIS"] = self.df.Y_PPM.apply(lambda x: self.uc_f1(x, "ppm"))
+        # decimal point values
+        self.df["X_AXISf"] = self.df.X_PPM.apply(lambda x: self.uc_f2.f(x, "ppm"))
+        self.df["Y_AXISf"] = self.df.Y_PPM.apply(lambda x: self.uc_f1.f(x, "ppm"))
+        # in case of missing values (should estimate though)
+        self.df["XW_HZ"] = self.df.XW_HZ.replace(np.NaN, "20.0")
+        self.df["YW_HZ"] = self.df.YW_HZ.replace(np.NaN, "20.0")
+        # convert linewidths to float
+        self.df["XW_HZ"] = self.df.XW_HZ.apply(lambda x: float(x))
+        self.df["YW_HZ"] = self.df.YW_HZ.apply(lambda x: float(x))
+        # convert Hz linewidths to points
+        self.df["XW"] = self.df.XW_HZ.apply(lambda x: x * self.pt_per_hz_f2)
+        self.df["YW"] = self.df.YW_HZ.apply(lambda x: x * self.pt_per_hz_f1)
+        # makes an assignment column
+        if self.fmt == PeaklistFormat.a2:
+            self.df["ASS"] = self.df.apply(
+                lambda i: "".join([i["Assign F1"], i["Assign F2"]]), axis=1
+            )
+
+        # radii are retained from the previous run rather than reset to defaults
+        # self.df["X_RADIUS_PPM"] = np.zeros(len(self.df)) + self.f2_radius
+        # self.df["Y_RADIUS_PPM"] = np.zeros(len(self.df)) + self.f1_radius
+        self.df["X_RADIUS"] = self.df.X_RADIUS_PPM.apply(
+            lambda x: x * self.pt_per_ppm_f2
+        )
+        self.df["Y_RADIUS"] = self.df.Y_RADIUS_PPM.apply(
+            lambda x: x * self.pt_per_ppm_f1
+        )
+        # add include column if it doesn't exist already
+        if "include" not in self.df.columns:
+            self.df["include"] = self.df.apply(lambda x: "yes", axis=1)
+
+        # check assignments for duplicates
+        self.check_assignments()
+        # check that peaks are within the bounds of the data
+        self.check_peak_bounds()
+
+
+def get_vclist(vclist, args):
+    # read vclist
+    if vclist is None:
+        vclist = False
+    elif vclist.exists():
+        vclist_data = np.genfromtxt(vclist)
+        args["vclist_data"] = vclist_data
+        vclist = True
+    else:
+        raise Exception("vclist not found...")
+
+    args["vclist"] = vclist
+    return args
diff --git a/peakipy/lineshapes.py b/peakipy/lineshapes.py
new file mode 100644
index 00000000..38b53e26
--- /dev/null
+++ b/peakipy/lineshapes.py
@@ -0,0 +1,522 @@
+from enum import Enum
+
+import pandas as pd
+from numpy import sqrt, exp, log
+from scipy.special import wofz
+
+from peakipy.constants import π, tiny, log2
+
+
+class Lineshape(str, Enum):
+    PV = "PV"
+    V = "V"
+    G = "G"
+    L = "L"
+    PV_PV = "PV_PV"
+    G_L = "G_L"
+    PV_G = "PV_G"
+    PV_L = "PV_L"
+
+
+def gaussian(x, center=0.0, sigma=1.0):
+    r"""1-dimensional Gaussian function.
+
+    gaussian(x, center, sigma) =
+        (1 / (sqrt(2*pi) * sigma)) * exp(-(1.0*x - center)**2 / (2 * sigma**2))
+
+    :math:`\frac{1}{\sqrt{2\pi}\sigma} \exp\left(\frac{-(x-center)^2}{2\sigma^2}\right)`
+
+    :param x: x
+    :param center: center
+    :param sigma: sigma
+    :type x: numpy.array
+    :type center: float
+    :type sigma: float
+
+    :return: 1-dimensional Gaussian
+    :rtype: numpy.array
+
+    """
+    return (1.0 / max(tiny, (sqrt(2 * π) * sigma))) * exp(
+        -((1.0 * x - center) ** 2) / max(tiny, (2 * sigma**2))
+    )
+
+
+def lorentzian(x, center=0.0, sigma=1.0):
+    r"""1-dimensional Lorentzian function.
+
+    lorentzian(x, center, sigma) =
+        (1 / (1 + ((1.0*x - center) / sigma)**2)) / (pi * sigma)
+
+    :math:`\frac{1}{1 + \left(\frac{x-center}{\sigma}\right)^2} / (\pi\sigma)`
+
+    :param x: x
+    :param center: center
+    :param sigma: sigma
+    :type x: numpy.array
+    :type center: float
+    :type sigma: float
+
+    :return: 1-dimensional Lorentzian
+    :rtype: numpy.array
+
+    """
+    return (1.0 / (1 + ((1.0 * x - center) / max(tiny, sigma)) ** 2)) / max(
+        tiny, (π * sigma)
+    )
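Both 1D profiles are normalised to unit area (up to the `tiny` guard against division by zero); a quick numerical check, as a sketch:

```python
import numpy as np
from peakipy.lineshapes import gaussian, lorentzian

x = np.linspace(-500.0, 500.0, 200001)
print(np.trapz(gaussian(x, center=0.0, sigma=2.0), x))    # ~1.0
print(np.trapz(lorentzian(x, center=0.0, sigma=2.0), x))  # ~1.0, minus the slow tails
```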
+def voigt(x, center=0.0, sigma=1.0, gamma=None):
+    r"""Return a 1-dimensional Voigt function.
+
+    voigt(x, center, sigma, gamma) =
+        wofz(z).real / (sigma * sqrt(2.0 * π))
+
+    :math:`V(x,\sigma,\gamma) = \frac{Re[\omega(z)]}{\sigma \sqrt{2\pi}}`
+
+    :math:`z = \frac{x + i\gamma}{\sigma\sqrt{2}}`
+
+    see Voigt_ wiki
+
+    .. _Voigt: https://en.wikipedia.org/wiki/Voigt_profile
+
+    :param x: x values
+    :type x: numpy array 1d
+    :param center: center of lineshape in points
+    :type center: float
+    :param sigma: sigma of gaussian
+    :type sigma: float
+    :param gamma: gamma of lorentzian (defaults to sigma when None)
+    :type gamma: float
+
+    :returns: Voigt lineshape
+    :rtype: numpy.array
+
+    """
+    if gamma is None:
+        gamma = sigma
+
+    z = (x - center + 1j * gamma) / max(tiny, (sigma * sqrt(2.0)))
+    return wofz(z).real / max(tiny, (sigma * sqrt(2.0 * π)))
+
+
+def pseudo_voigt(x, center=0.0, sigma=1.0, fraction=0.5):
+    r"""1-dimensional pseudo-Voigt function
+
+    Superposition of Gaussian and Lorentzian functions
+
+    :math:`(1-\phi) G(x, center, \sigma_g) + \phi L(x, center, \sigma)`
+
+    Where :math:`\phi` is the fraction of Lorentzian lineshape and :math:`G` and :math:`L` are Gaussian and
+    Lorentzian functions, respectively.
+
+    :param x: data
+    :type x: numpy.array
+    :param center: center of peak
+    :type center: float
+    :param sigma: sigma of lineshape
+    :type sigma: float
+    :param fraction: fraction of lorentzian lineshape (between 0 and 1)
+    :type fraction: float
+
+    :return: pseudo-voigt function
+    :rtype: numpy.array
+
+    """
+    sigma_g = sigma / sqrt(2 * log2)
+    pv = (1 - fraction) * gaussian(x, center, sigma_g) + fraction * lorentzian(
+        x, center, sigma
+    )
+    return pv
+
+
+def pvoigt2d(
+    XY,
+    amplitude=1.0,
+    center_x=0.5,
+    center_y=0.5,
+    sigma_x=1.0,
+    sigma_y=1.0,
+    fraction=0.5,
+):
+    r"""2D pseudo-Voigt model
+
+    :math:`\left[(1-fraction)\, G(x, center_x, \sigma_{gx}) + fraction\, L(x, center_x, \sigma_x)\right] \times \left[(1-fraction)\, G(y, center_y, \sigma_{gy}) + fraction\, L(y, center_y, \sigma_y)\right]`
+
+    :param XY: meshgrid of X and Y coordinates [X, Y], each with the shape of Z
+    :type XY: numpy.array
+
+    :param amplitude: amplitude of peak
+    :type amplitude: float
+
+    :param center_x: center of peak in x
+    :type center_x: float
+
+    :param center_y: center of peak in y
+    :type center_y: float
+
+    :param sigma_x: sigma of lineshape in x
+    :type sigma_x: float
+
+    :param sigma_y: sigma of lineshape in y
+    :type sigma_y: float
+
+    :param fraction: fraction of lorentzian lineshape (between 0 and 1)
+    :type fraction: float
+
+    :return: flattened array of Z values (use Z.reshape(X.shape) for recovery)
+    :rtype: numpy.array
+
+    """
+    x, y = XY
+    pv_x = pseudo_voigt(x, center_x, sigma_x, fraction)
+    pv_y = pseudo_voigt(y, center_y, sigma_y, fraction)
+    return amplitude * pv_x * pv_y
+
+
+def pv_l(
+    XY,
+    amplitude=1.0,
+    center_x=0.5,
+    center_y=0.5,
+    sigma_x=1.0,
+    sigma_y=1.0,
+    fraction=0.5,
+):
+    """2D lineshape model with pseudo-voigt in x and lorentzian in y
+
+    Arguments
+    =========
+
+    -- XY: meshgrid of X and Y coordinates [X,Y] each with shape of Z
+    -- amplitude: peak amplitude (gaussian and lorentzian)
+    -- center_x: position of peak in x
+    -- center_y: position of peak in y
+    -- sigma_x: linewidth in x
+    -- sigma_y: linewidth in y
+    -- fraction: fraction of lorentzian in fit
+
+    Returns
+    =======
+
+    -- flattened array of Z values (use Z.reshape(X.shape) for recovery)
+
+    """
+
+    x, y = XY
+    pv_x = pseudo_voigt(x, center_x, sigma_x, fraction)
+    pv_y = pseudo_voigt(y, center_y, sigma_y, 1.0)  # lorentzian
+    return amplitude * pv_x * pv_y
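A sketch of evaluating the 2D pseudo-Voigt on flattened meshgrid coordinates, matching the `Z.reshape(X.shape)` recipe in the docstring:

```python
import numpy as np
from peakipy.lineshapes import pvoigt2d

X, Y = np.meshgrid(np.arange(50), np.arange(40))  # shapes (40, 50)
Z = pvoigt2d(
    [X.ravel(), Y.ravel()],
    amplitude=1000.0,
    center_x=25.0,
    center_y=20.0,
    sigma_x=3.0,
    sigma_y=2.0,
    fraction=0.5,
).reshape(X.shape)  # recover the 2D surface
```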
+def pv_g(
+    XY,
+    amplitude=1.0,
+    center_x=0.5,
+    center_y=0.5,
+    sigma_x=1.0,
+    sigma_y=1.0,
+    fraction=0.5,
+):
+    """2D lineshape model with pseudo-voigt in x and gaussian in y
+
+    Arguments
+    ---------
+
+    -- XY: meshgrid of X and Y coordinates [X,Y] each with shape of Z
+    -- amplitude: peak amplitude (gaussian and lorentzian)
+    -- center_x: position of peak in x
+    -- center_y: position of peak in y
+    -- sigma_x: linewidth in x
+    -- sigma_y: linewidth in y
+    -- fraction: fraction of lorentzian in fit
+
+    Returns
+    -------
+
+    -- flattened array of Z values (use Z.reshape(X.shape) for recovery)
+
+    """
+    x, y = XY
+    pv_x = pseudo_voigt(x, center_x, sigma_x, fraction)
+    pv_y = pseudo_voigt(y, center_y, sigma_y, 0.0)  # gaussian
+    return amplitude * pv_x * pv_y
+
+
+def pv_pv(
+    XY,
+    amplitude=1.0,
+    center_x=0.5,
+    center_y=0.5,
+    sigma_x=1.0,
+    sigma_y=1.0,
+    fraction_x=0.5,
+    fraction_y=0.5,
+):
+    """2D lineshape model with pseudo-voigt in x and pseudo-voigt in y
+    i.e. separate fraction_x and fraction_y params
+
+    Arguments
+    =========
+
+    -- XY: meshgrid of X and Y coordinates [X,Y] each with shape of Z
+    -- amplitude: peak amplitude (gaussian and lorentzian)
+    -- center_x: position of peak in x
+    -- center_y: position of peak in y
+    -- sigma_x: linewidth in x
+    -- sigma_y: linewidth in y
+    -- fraction_x: fraction of lorentzian in x
+    -- fraction_y: fraction of lorentzian in y
+
+    Returns
+    =======
+
+    -- flattened array of Z values (use Z.reshape(X.shape) for recovery)
+
+    """
+
+    x, y = XY
+    pv_x = pseudo_voigt(x, center_x, sigma_x, fraction_x)
+    pv_y = pseudo_voigt(y, center_y, sigma_y, fraction_y)
+    return amplitude * pv_x * pv_y
+
+
+def gaussian_lorentzian(
+    XY,
+    amplitude=1.0,
+    center_x=0.5,
+    center_y=0.5,
+    sigma_x=1.0,
+    sigma_y=1.0,
+    fraction=0.5,
+):
+    """2D lineshape model with gaussian in x and lorentzian in y
+
+    Arguments
+    =========
+
+    -- XY: meshgrid of X and Y coordinates [X,Y] each with shape of Z
+    -- amplitude: peak amplitude (gaussian and lorentzian)
+    -- center_x: position of peak in x
+    -- center_y: position of peak in y
+    -- sigma_x: linewidth in x
+    -- sigma_y: linewidth in y
+    -- fraction: fraction of lorentzian in fit
+
+    Returns
+    =======
+
+    -- flattened array of Z values (use Z.reshape(X.shape) for recovery)
+
+    """
+    x, y = XY
+    pv_x = pseudo_voigt(x, center_x, sigma_x, 0.0)  # gaussian
+    pv_y = pseudo_voigt(y, center_y, sigma_y, 1.0)  # lorentzian
+    return amplitude * pv_x * pv_y
+
+
+def voigt2d(
+    XY,
+    amplitude=1.0,
+    center_x=0.5,
+    center_y=0.5,
+    sigma_x=1.0,
+    sigma_y=1.0,
+    gamma_x=1.0,
+    gamma_y=1.0,
+    fraction=0.5,
+):
+    """2D Voigt model: a product of 1D Voigt profiles in x and y
+
+    fraction is accepted for API compatibility with the other 2D models but
+    is not used; gamma_x/gamma_y are passed straight through to the 1D Voigt
+    profiles.
+    """
+    x, y = XY
+    voigt_x = voigt(x, center_x, sigma_x, gamma_x)
+    voigt_y = voigt(y, center_y, sigma_y, gamma_y)
+    return amplitude * voigt_x * voigt_y
+
+
+def get_lineshape_function(lineshape: Lineshape):
+    match lineshape:
+        case lineshape.PV | lineshape.G | lineshape.L:
+            lineshape_function = pvoigt2d
+        case lineshape.V:
+            lineshape_function = voigt2d
+        case lineshape.PV_PV:
+            lineshape_function = pv_pv
+        case lineshape.G_L:
+            lineshape_function = gaussian_lorentzian
+        case lineshape.PV_G:
+            lineshape_function = pv_g
+        case lineshape.PV_L:
+            lineshape_function = pv_l
+        case _:
+            raise Exception("No lineshape was selected!")
+    return lineshape_function
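`get_lineshape_function` is how the fitting layer picks a model callable; a minimal sketch of wiring it into lmfit:

```python
from lmfit import Model
from peakipy.lineshapes import Lineshape, get_lineshape_function

lineshape_function = get_lineshape_function(Lineshape.PV)  # pvoigt2d
model = Model(lineshape_function, prefix="_peak1_", independent_vars=["XY"])
print(model.param_names)
# ['_peak1_amplitude', '_peak1_center_x', '_peak1_center_y', ...]
```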
+def calculate_height_for_voigt_lineshape(df):
+    df["height"] = df.apply(
+        lambda x: voigt2d(
+            XY=[0, 0],
+            center_x=0.0,
+            center_y=0.0,
+            sigma_x=x.sigma_x,
+            sigma_y=x.sigma_y,
+            gamma_x=x.gamma_x,
+            gamma_y=x.gamma_y,
+            amplitude=x.amp,
+        ),
+        axis=1,
+    )
+    df["height_err"] = df.apply(
+        lambda x: x.amp_err * (x.height / x.amp) if x.amp_err is not None else 0.0,
+        axis=1,
+    )
+    return df
+
+
+def calculate_fwhm_for_voigt_lineshape(df):
+    df["fwhm_g_x"] = df.sigma_x.apply(
+        lambda x: 2.0 * x * sqrt(2.0 * log(2.0))
+    )  # fwhm of gaussian
+    df["fwhm_g_y"] = df.sigma_y.apply(lambda x: 2.0 * x * sqrt(2.0 * log(2.0)))
+    df["fwhm_l_x"] = df.gamma_x.apply(lambda x: 2.0 * x)  # fwhm of lorentzian
+    df["fwhm_l_y"] = df.gamma_y.apply(lambda x: 2.0 * x)
+    # combined Voigt FWHM via the Olivero-Longbothum approximation
+    df["fwhm_x"] = df.apply(
+        lambda x: 0.5346 * x.fwhm_l_x
+        + sqrt(0.2166 * x.fwhm_l_x**2.0 + x.fwhm_g_x**2.0),
+        axis=1,
+    )
+    df["fwhm_y"] = df.apply(
+        lambda x: 0.5346 * x.fwhm_l_y
+        + sqrt(0.2166 * x.fwhm_l_y**2.0 + x.fwhm_g_y**2.0),
+        axis=1,
+    )
+    return df
+
+
+def calculate_height_for_pseudo_voigt_lineshape(df):
+    df["height"] = df.apply(
+        lambda x: pvoigt2d(
+            XY=[0, 0],
+            center_x=0.0,
+            center_y=0.0,
+            sigma_x=x.sigma_x,
+            sigma_y=x.sigma_y,
+            amplitude=x.amp,
+            fraction=x.fraction,
+        ),
+        axis=1,
+    )
+    df["height_err"] = df.apply(lambda x: x.amp_err * (x.height / x.amp), axis=1)
+    return df
+
+
+def calculate_fwhm_for_pseudo_voigt_lineshape(df):
+    df["fwhm_x"] = df.sigma_x.apply(lambda x: x * 2.0)
+    df["fwhm_y"] = df.sigma_y.apply(lambda x: x * 2.0)
+    return df
+
+
+def calculate_height_for_gaussian_lineshape(df):
+    df["height"] = df.apply(
+        lambda x: pvoigt2d(
+            XY=[0, 0],
+            center_x=0.0,
+            center_y=0.0,
+            sigma_x=x.sigma_x,
+            sigma_y=x.sigma_y,
+            amplitude=x.amp,
+            fraction=0.0,  # gaussian
+        ),
+        axis=1,
+    )
+    df["height_err"] = df.apply(lambda x: x.amp_err * (x.height / x.amp), axis=1)
+    return df
+
+
+def calculate_height_for_lorentzian_lineshape(df):
+    df["height"] = df.apply(
+        lambda x: pvoigt2d(
+            XY=[0, 0],
+            center_x=0.0,
+            center_y=0.0,
+            sigma_x=x.sigma_x,
+            sigma_y=x.sigma_y,
+            amplitude=x.amp,
+            fraction=1.0,  # lorentzian
+        ),
+        axis=1,
+    )
+    df["height_err"] = df.apply(lambda x: x.amp_err * (x.height / x.amp), axis=1)
+    return df
+
+
+def calculate_height_for_pv_pv_lineshape(df):
+    df["height"] = df.apply(
+        lambda x: pv_pv(
+            XY=[0, 0],
+            center_x=0.0,
+            center_y=0.0,
+            sigma_x=x.sigma_x,
+            sigma_y=x.sigma_y,
+            amplitude=x.amp,
+            fraction_x=x.fraction_x,
+            fraction_y=x.fraction_y,
+        ),
+        axis=1,
+    )
+    df["height_err"] = df.apply(lambda x: x.amp_err * (x.height / x.amp), axis=1)
+    return df
+
+
+def calculate_peak_centers_in_ppm(df, peakipy_data):
+    # convert values to ppm
+    df["center_x_ppm"] = df.center_x.apply(lambda x: peakipy_data.uc_f2.ppm(x))
+    df["center_y_ppm"] = df.center_y.apply(lambda x: peakipy_data.uc_f1.ppm(x))
+    df["init_center_x_ppm"] = df.init_center_x.apply(
+        lambda x: peakipy_data.uc_f2.ppm(x)
+    )
+    df["init_center_y_ppm"] = df.init_center_y.apply(
+        lambda x: peakipy_data.uc_f1.ppm(x)
+    )
+    return df
+
+
+def calculate_peak_linewidths_in_hz(df, peakipy_data):
+    # convert sigma and fwhm from points to ppm and Hz
+    df["sigma_x_ppm"] = df.sigma_x.apply(lambda x: x * peakipy_data.ppm_per_pt_f2)
+    df["sigma_y_ppm"] = df.sigma_y.apply(lambda x: x * peakipy_data.ppm_per_pt_f1)
+    df["fwhm_x_ppm"] = df.fwhm_x.apply(lambda x: x * peakipy_data.ppm_per_pt_f2)
+    df["fwhm_y_ppm"] = df.fwhm_y.apply(lambda x: x * peakipy_data.ppm_per_pt_f1)
+    df["fwhm_x_hz"] = df.fwhm_x.apply(lambda x: x * peakipy_data.hz_per_pt_f2)
+    df["fwhm_y_hz"] = df.fwhm_y.apply(lambda x: x * peakipy_data.hz_per_pt_f1)
+    return df
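The combined widths above use the Olivero-Longbothum approximation for the Voigt FWHM, `fwhm ~ 0.5346*fwhm_l + sqrt(0.2166*fwhm_l**2 + fwhm_g**2)`. A worked check on a single made-up row, as a sketch:

```python
import pandas as pd
from peakipy.lineshapes import calculate_fwhm_for_voigt_lineshape

# made-up sigma/gamma values in points
df = pd.DataFrame(
    {"sigma_x": [1.0], "sigma_y": [1.2], "gamma_x": [0.5], "gamma_y": [0.4]}
)
df = calculate_fwhm_for_voigt_lineshape(df)
# fwhm_g_x = 2*1.0*sqrt(2 ln 2) ~ 2.355, fwhm_l_x = 2*0.5 = 1.0
# fwhm_x ~ 0.5346*1.0 + sqrt(0.2166*1.0**2 + 2.355**2) ~ 2.94
print(df[["fwhm_g_x", "fwhm_l_x", "fwhm_x"]])
```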
+def calculate_lineshape_specific_height_and_fwhm(
+    lineshape: Lineshape, df: pd.DataFrame
+):
+    match lineshape:
+        case lineshape.V:
+            df = calculate_height_for_voigt_lineshape(df)
+            df = calculate_fwhm_for_voigt_lineshape(df)
+
+        case lineshape.PV:
+            df = calculate_height_for_pseudo_voigt_lineshape(df)
+            df = calculate_fwhm_for_pseudo_voigt_lineshape(df)
+
+        case lineshape.G:
+            df = calculate_height_for_gaussian_lineshape(df)
+            df = calculate_fwhm_for_pseudo_voigt_lineshape(df)
+
+        case lineshape.L:
+            df = calculate_height_for_lorentzian_lineshape(df)
+            df = calculate_fwhm_for_pseudo_voigt_lineshape(df)
+
+        case lineshape.PV_PV:
+            df = calculate_height_for_pv_pv_lineshape(df)
+            df = calculate_fwhm_for_pseudo_voigt_lineshape(df)
+        case _:
+            df = calculate_fwhm_for_pseudo_voigt_lineshape(df)
+    return df
diff --git a/peakipy/plotting.py b/peakipy/plotting.py
new file mode 100644
index 00000000..479f4ace
--- /dev/null
+++ b/peakipy/plotting.py
@@ -0,0 +1,400 @@
+from dataclasses import dataclass, field
+from typing import List
+
+import pandas as pd
+import numpy as np
+import plotly.graph_objects as go
+import matplotlib.pyplot as plt
+from matplotlib import cm
+from matplotlib.widgets import Button
+from matplotlib.backends.backend_pdf import PdfPages
+from rich import print
+
+from peakipy.io import Pseudo3D
+from peakipy.utils import df_to_rich_table, bad_color_selection, bad_column_selection
+
+
+@dataclass
+class PlottingDataForPlane:
+    pseudo3D: Pseudo3D
+    plane_id: int
+    plane_lineshape_parameters: pd.DataFrame
+    X: np.ndarray
+    Y: np.ndarray
+    mask: np.ndarray
+    individual_masks: List[np.ndarray]
+    sim_data: np.ndarray
+    sim_data_singles: List[np.ndarray]
+    min_x: int
+    max_x: int
+    min_y: int
+    max_y: int
+    fit_color: str
+    data_color: str
+    rcount: int
+    ccount: int
+
+    x_plot: np.ndarray = field(init=False)
+    y_plot: np.ndarray = field(init=False)
+    masked_data: np.ndarray = field(init=False)
+    masked_sim_data: np.ndarray = field(init=False)
+    residual: np.ndarray = field(init=False)
+    single_colors: List = field(init=False)
+
+    def __post_init__(self):
+        self.plane_data = self.pseudo3D.data[self.plane_id]
+        self.masked_data = self.plane_data.copy()
+        self.masked_sim_data = self.sim_data.copy()
+        self.masked_data[~self.mask] = np.nan
+        self.masked_sim_data[~self.mask] = np.nan
+
+        self.x_plot = self.pseudo3D.uc_f2.ppm(
+            self.X[self.min_y : self.max_y, self.min_x : self.max_x]
+        )
+        self.y_plot = self.pseudo3D.uc_f1.ppm(
+            self.Y[self.min_y : self.max_y, self.min_x : self.max_x]
+        )
+        self.masked_data = self.masked_data[
+            self.min_y : self.max_y, self.min_x : self.max_x
+        ]
+        self.sim_plot = self.masked_sim_data[
+            self.min_y : self.max_y, self.min_x : self.max_x
+        ]
+        self.residual = self.masked_data - self.sim_plot
+
+        for single_mask, single in zip(self.individual_masks, self.sim_data_singles):
+            single[~single_mask] = np.nan
+        self.sim_data_singles = [
+            sim_data_single[self.min_y : self.max_y, self.min_x : self.max_x]
+            for sim_data_single in self.sim_data_singles
+        ]
+        self.single_colors = [
+            cm.viridis(i) for i in np.linspace(0, 1, len(self.sim_data_singles))
+        ]
+
+
+def plot_data_is_valid(plot_data: PlottingDataForPlane) -> bool:
+    if len(plot_data.x_plot) < 1 or len(plot_data.y_plot) < 1:
+        print(
+            f"[red]Nothing to plot for cluster {int(plot_data.plane_lineshape_parameters.clustid.iloc[0])}[/red]"
+        )
+        print(f"[red]x={plot_data.x_plot},y={plot_data.y_plot}[/red]")
+        print(
+            df_to_rich_table(
+                plot_data.plane_lineshape_parameters,
+                title="",
+                columns=bad_column_selection,
+                styles=bad_color_selection,
+            )
+        )
+        plt.close()
+        validated = False
+        # print(Fore.RED + "Maybe your F1/F2 radii for fitting were too small...")
+    elif plot_data.masked_data.shape[0] == 0 or plot_data.masked_data.shape[1] == 0:
+        print(
+            f"[red]Nothing to plot for cluster {int(plot_data.plane_lineshape_parameters.clustid.iloc[0])}[/red]"
+        )
+        print(
+            df_to_rich_table(
+                plot_data.plane_lineshape_parameters,
+                title="Bad plane",
+                columns=bad_column_selection,
+                styles=bad_color_selection,
+            )
+        )
+        spec_lim_f1 = " - ".join(
+            ["%8.3f" % i for i in plot_data.pseudo3D.f1_ppm_limits]
+        )
+        spec_lim_f2 = " - ".join(
+            ["%8.3f" % i for i in plot_data.pseudo3D.f2_ppm_limits]
+        )
+        print(f"Spectrum limits are {plot_data.pseudo3D.f2_label:4s}:{spec_lim_f2} ppm")
+        print(f"                    {plot_data.pseudo3D.f1_label:4s}:{spec_lim_f1} ppm")
+        plt.close()
+        validated = False
+    else:
+        validated = True
+    return validated
+
+
+def create_matplotlib_figure(
+    plot_data: PlottingDataForPlane,
+    pdf: PdfPages,
+    individual=False,
+    label=False,
+    ccpn_flag=False,
+    show=True,
+):
+    fig = plt.figure(figsize=(10, 6))
+    ax = fig.add_subplot(projection="3d")
+    if plot_data_is_valid(plot_data):
+        cset = ax.contourf(
+            plot_data.x_plot,
+            plot_data.y_plot,
+            plot_data.residual,
+            zdir="z",
+            offset=np.nanmin(plot_data.masked_data) * 1.1,
+            alpha=0.5,
+            cmap=cm.coolwarm,
+        )
+        cbl = fig.colorbar(cset, ax=ax, shrink=0.5, format="%.2e")
+        cbl.ax.set_title("Residual", pad=20)
+
+        if individual:
+            # for plotting single fit surfaces
+            single_colors = [
+                cm.viridis(i)
+                for i in np.linspace(0, 1, len(plot_data.sim_data_singles))
+            ]
+            [
+                ax.plot_surface(
+                    plot_data.x_plot,
+                    plot_data.y_plot,
+                    z_single,
+                    color=c,
+                    alpha=0.5,
+                )
+                for c, z_single in zip(single_colors, plot_data.sim_data_singles)
+            ]
+        ax.plot_wireframe(
+            plot_data.x_plot,
+            plot_data.y_plot,
+            plot_data.sim_plot,
+            # colors=[cm.coolwarm(i) for i in np.ravel(residual)],
+            colors=plot_data.fit_color,
+            linestyle="--",
+            label="fit",
+            rcount=plot_data.rcount,
+            ccount=plot_data.ccount,
+        )
+        ax.plot_wireframe(
+            plot_data.x_plot,
+            plot_data.y_plot,
+            plot_data.masked_data,
+            colors=plot_data.data_color,
+            linestyle="-",
+            label="data",
+            rcount=plot_data.rcount,
+            ccount=plot_data.ccount,
+        )
+        ax.set_ylabel(plot_data.pseudo3D.f1_label)
+        ax.set_xlabel(plot_data.pseudo3D.f2_label)
+
+        # axes will appear inverted
+        ax.view_init(30, 120)
+
+        title = f"Plane={plot_data.plane_id},Cluster={plot_data.plane_lineshape_parameters.clustid.iloc[0]}"
+        plt.title(title)
+        print(f"[green]Plotting: {title}[/green]")
+        out_str = "Volumes (Heights)\n===========\n"
+        for _, row in plot_data.plane_lineshape_parameters.iterrows():
+            out_str += f"{row.assignment} = {row.amp:.3e} ({row.height:.3e})\n"
+            if label:
+                ax.text(
+                    row.center_x_ppm,
+                    row.center_y_ppm,
+                    row.height * 1.2,
+                    row.assignment,
+                    (1, 1, 1),
+                )
+
+        ax.text2D(
+            -0.5,
+            1.0,
+            out_str,
+            transform=ax.transAxes,
+            fontsize=10,
+            fontfamily="sans-serif",
+            va="top",
+            bbox=dict(boxstyle="round", ec="k", fc="k", alpha=0.5),
+        )
+
+        ax.legend()
+
+        if show:
+
+            def exit_program(event):
+                exit()
+
+            def next_plot(event):
+                plt.close()
+
+            axexit = plt.axes([0.81, 0.05, 0.1, 0.075])
+            bnexit = Button(axexit, "Exit")
+            bnexit.on_clicked(exit_program)
+            axnext = plt.axes([0.71, 0.05, 0.1, 0.075])
+            bnnext = Button(axnext, "Next")
+            bnnext.on_clicked(next_plot)
+            if ccpn_flag:
+                plt.show(windowTitle="", size=(1000, 500))
+            else:
+                plt.show()
+        else:
+            pdf.savefig()
+
+        plt.close()
+
+
+def create_plotly_wireframe_lines(plot_data: PlottingDataForPlane):
+    lines = []
+    show_legend = lambda x: x < 1
+    showlegend = False
+    # make simulated data wireframe
+    line_marker = dict(color=plot_data.fit_color, width=4)
+    counter = 0
+    for i, j, k in zip(plot_data.x_plot, plot_data.y_plot, plot_data.sim_plot):
+        showlegend = show_legend(counter)
+        lines.append(
+            go.Scatter3d(
+                x=i,
+                y=j,
+                z=k,
+                mode="lines",
+                line=line_marker,
+                name="fit",
+                showlegend=showlegend,
+            )
+        )
+        counter += 1
+    for i, j, k in zip(plot_data.x_plot.T, plot_data.y_plot.T, plot_data.sim_plot.T):
+        lines.append(
+            go.Scatter3d(
+                x=i, y=j, z=k, mode="lines", line=line_marker, showlegend=showlegend
+            )
+        )
+    # make experimental data wireframe
+    line_marker = dict(color=plot_data.data_color, width=4)
+    counter = 0
+    for i, j, k in zip(plot_data.x_plot, plot_data.y_plot, plot_data.masked_data):
+        showlegend = show_legend(counter)
+        lines.append(
+            go.Scatter3d(
+                x=i,
+                y=j,
+                z=k,
+                mode="lines",
+                name="data",
+                line=line_marker,
+                showlegend=showlegend,
+            )
+        )
+        counter += 1
+    for i, j, k in zip(plot_data.x_plot.T, plot_data.y_plot.T, plot_data.masked_data.T):
+        lines.append(
+            go.Scatter3d(
+                x=i, y=j, z=k, mode="lines", line=line_marker, showlegend=showlegend
+            )
+        )
+
+    return lines
+
+
+def construct_surface_legend_string(row):
+    surface_legend = ""
+    surface_legend += row.assignment
+    return surface_legend
+
+
+def create_plotly_surfaces(plot_data: PlottingDataForPlane):
+    data = []
+    color_scale_values = np.linspace(0, 1, len(plot_data.single_colors))
+    color_scale = [
+        [val, f"rgb({', '.join('%d' % (i * 255) for i in c[0:3])})"]
+        for val, c in zip(color_scale_values, plot_data.single_colors)
+    ]
+    for val, individual_peak, row in zip(
+        color_scale_values,
+        plot_data.sim_data_singles,
+        plot_data.plane_lineshape_parameters.itertuples(),
+    ):
+        name = construct_surface_legend_string(row)
+        colors = np.zeros(shape=individual_peak.shape) + val
+        data.append(
+            go.Surface(
+                z=individual_peak,
+                x=plot_data.x_plot,
+                y=plot_data.y_plot,
+                opacity=0.5,
+                surfacecolor=colors,
+                colorscale=color_scale,
+                showscale=False,
+                cmin=0,
+                cmax=1,
+                name=name,
+            )
+        )
+    return data
+
+
+def create_residual_contours(plot_data: PlottingDataForPlane):
+    contours = go.Contour(
+        x=plot_data.x_plot[0], y=plot_data.y_plot.T[0], z=plot_data.residual
+    )
+    return contours
+
+
+def create_residual_figure(plot_data: PlottingDataForPlane):
+    data = create_residual_contours(plot_data)
+    fig = go.Figure(data=data)
+    fig.update_layout(
+        title="Fit residuals",
+        xaxis_title=f"{plot_data.pseudo3D.f2_label} ppm",
+        yaxis_title=f"{plot_data.pseudo3D.f1_label} ppm",
+        xaxis=dict(range=[plot_data.x_plot.max(), plot_data.x_plot.min()]),
+        yaxis=dict(range=[plot_data.y_plot.max(), plot_data.y_plot.min()]),
+    )
+    return fig
+
+
+def create_plotly_figure(plot_data: PlottingDataForPlane):
+    lines = create_plotly_wireframe_lines(plot_data)
+    surfaces = create_plotly_surfaces(plot_data)
+    fig = go.Figure(data=lines + surfaces)
+    fig = update_axis_ranges(fig, plot_data)
+    return fig
+
+
+def update_axis_ranges(fig, plot_data: PlottingDataForPlane):
+    fig.update_layout(
+        scene=dict(
+            xaxis=dict(range=[plot_data.x_plot.max(), plot_data.x_plot.min()]),
+            yaxis=dict(range=[plot_data.y_plot.max(), plot_data.y_plot.min()]),
+            xaxis_title=f"{plot_data.pseudo3D.f2_label} ppm",
+            yaxis_title=f"{plot_data.pseudo3D.f1_label} ppm",
+            annotations=make_annotations(plot_data),
+        )
+    )
+    return fig
+
+
+def make_annotations(plot_data: PlottingDataForPlane):
+    annotations = []
+    for row in plot_data.plane_lineshape_parameters.itertuples():
+        annotations.append(
+            dict(
+                showarrow=True,
+                x=row.center_x_ppm,
+                y=row.center_y_ppm,
+                z=row.height * 1.0,
+                text=row.assignment,
+                opacity=0.8,
+                textangle=0,
+                arrowsize=1,
+            )
+        )
+    return annotations
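A sketch of how the plotly builders compose, assuming `plot_data` is a populated `PlottingDataForPlane` produced by the fitting pipeline:

```python
from peakipy.plotting import create_plotly_figure, create_residual_figure

# plot_data: a PlottingDataForPlane assembled elsewhere (assumed here)
fig = create_plotly_figure(plot_data)             # wireframes + per-peak surfaces
residual_fig = create_residual_figure(plot_data)  # contour map of data - fit
fig.show()
residual_fig.show()
```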
+def validate_sample_count(sample_count):
+    if not isinstance(sample_count, int):
+        raise TypeError("Sample count (ccount, rcount) should be an integer")
+    return sample_count
+
+
+def unpack_plotting_colors(colors):
+    match colors:
+        case (data_color, fit_color):
+            pass
+        case _:
+            data_color, fit_color = "green", "blue"
+    return data_color, fit_color
diff --git a/peakipy/utils.py b/peakipy/utils.py
new file mode 100644
index 00000000..310ee656
--- /dev/null
+++ b/peakipy/utils.py
@@ -0,0 +1,239 @@
+import sys
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import List
+
+from rich import print
+from rich.table import Table
+
+# for printing dataframes
+peaklist_columns_for_printing = ["INDEX", "ASS", "X_PPM", "Y_PPM", "CLUSTID", "MEMCNT"]
+bad_column_selection = [
+    "clustid",
+    "amp",
+    "center_x_ppm",
+    "center_y_ppm",
+    "fwhm_x_hz",
+    "fwhm_y_hz",
+    "lineshape",
+]
+bad_color_selection = [
+    "green",
+    "blue",
+    "yellow",
+    "red",
+    "yellow",
+    "red",
+    "magenta",
+]
+
+
+def run_log(log_name="run_log.txt"):
+    """Write a log file containing the time the script was run and with which arguments"""
+    with open(log_name, "a") as log:
+        sys_argv = sys.argv
+        sys_argv[0] = Path(sys_argv[0]).name
+        run_args = " ".join(sys_argv)
+        time_stamp = datetime.now()
+        time_stamp = time_stamp.strftime("%A %d %B %Y at %H:%M")
+        log.write(f"# Script run on {time_stamp}:\n{run_args}\n")
+
+
+def df_to_rich_table(df, title: str, columns: List[str], styles: List[str]):
+    """Print dataframe using rich library
+
+    Parameters
+    ----------
+    df : pandas.DataFrame
+    title : str
+        title of table
+    columns : List[str]
+        list of column names (must be in df)
+    styles : List[str]
+        list of styles in same order as columns
+    """
+    table = Table(title=title)
+    for col, style in zip(columns, styles):
+        table.add_column(col, style=style)
+    for _, row in df.iterrows():
+        row = row[columns].values
+        str_row = []
+        for i in row:
+            match i:
+                case str():
+                    str_row.append(f"{i}")
+                case float() if i > 1e5:
+                    str_row.append(f"{i:.1e}")
+                case float():
+                    str_row.append(f"{i:.3f}")
+                case bool():
+                    str_row.append(f"{i}")
+                case int():
+                    str_row.append(f"{i}")
+        table.add_row(*str_row)
+    return table
+
+
+def load_config(config_path):
+    if config_path.exists():
+        with open(config_path) as opened_config:
+            config_dic = json.load(opened_config)
+            return config_dic
+    else:
+        return {}
+
+
+def write_config(config_path, config_dic):
+    with open(config_path, "w") as config:
+        config.write(json.dumps(config_dic, sort_keys=True, indent=4))
+
+
+def update_config_file(config_path, config_kvs):
+    config_dic = load_config(config_path)
+    config_dic.update(config_kvs)
+    write_config(config_path, config_dic)
+    return config_dic
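The three config helpers give a small JSON round trip; a usage sketch with throwaway values:

```python
from pathlib import Path
from peakipy.utils import load_config, write_config, update_config_file

config_path = Path("peakipy.config")
write_config(config_path, {"dims": [0, 1, 2]})
update_config_file(config_path, {"noise": 1.5e5})  # merge a new key
print(load_config(config_path))  # {'dims': [0, 1, 2], 'noise': 150000.0}
```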
args["dims"] = config.get("dims", (0, 1, 2)) + noise = config.get("noise") + if noise: + noise = float(noise) + + colors = config.get("colors", ["#5e3c99", "#e66101"]) + except json.decoder.JSONDecodeError: + print( + "[red]Your peakipy.config file is corrupted - maybe your JSON is not correct...[/red]" + ) + print("[red]Not using[/red]") + noise = False + colors = args.get("colors", ("#5e3c99", "#e66101")) + config = {} + else: + print( + "[red]No peakipy.config found - maybe you need to generate one with peakipy read or see docs[/red]" + ) + noise = False + colors = args.get("colors", ("#5e3c99", "#e66101")) + config = {} + + args["noise"] = noise + args["colors"] = colors + + return args, config + + +def update_linewidths_from_hz_to_points(peakipy_data): + """in case they were adjusted when running edit.py""" + peakipy_data.df["XW"] = peakipy_data.df.XW_HZ * peakipy_data.pt_per_hz_f2 + peakipy_data.df["YW"] = peakipy_data.df.YW_HZ * peakipy_data.pt_per_hz_f1 + return peakipy_data + + +def update_peak_positions_from_ppm_to_points(peakipy_data): + # convert peak positions from ppm to points in case they were adjusted running edit.py + peakipy_data.df["X_AXIS"] = peakipy_data.df.X_PPM.apply( + lambda x: peakipy_data.uc_f2(x, "PPM") + ) + peakipy_data.df["Y_AXIS"] = peakipy_data.df.Y_PPM.apply( + lambda x: peakipy_data.uc_f1(x, "PPM") + ) + peakipy_data.df["X_AXISf"] = peakipy_data.df.X_PPM.apply( + lambda x: peakipy_data.uc_f2.f(x, "PPM") + ) + peakipy_data.df["Y_AXISf"] = peakipy_data.df.Y_PPM.apply( + lambda x: peakipy_data.uc_f1.f(x, "PPM") + ) + return peakipy_data + + +def save_data(df, output_name): + suffix = output_name.suffix + + if suffix == ".csv": + df.to_csv(output_name, float_format="%.4f", index=False) + + elif suffix == ".tab": + df.to_csv(output_name, sep="\t", float_format="%.4f", index=False) + + else: + df.to_pickle(output_name) + + +def check_data_shape_is_consistent_with_dims(peakipy_data): + # check data shape is consistent with dims + if len(peakipy_data.dims) != len(peakipy_data.data.shape): + print( + f"Dims are {peakipy_data.dims} while data shape is {peakipy_data.data.shape}?" + ) + exit() + + +def check_for_include_column_and_add_if_missing(peakipy_data): + # only include peaks with 'include' + if "include" in peakipy_data.df.columns: + pass + else: + # for compatibility + peakipy_data.df["include"] = peakipy_data.df.apply(lambda _: "yes", axis=1) + return peakipy_data + + +def remove_excluded_peaks(peakipy_data): + if len(peakipy_data.df[peakipy_data.df.include != "yes"]) > 0: + excluded = peakipy_data.df[peakipy_data.df.include != "yes"][ + peaklist_columns_for_printing + ] + table = df_to_rich_table( + excluded, + title="[yellow] Excluded peaks [/yellow]", + columns=excluded.columns, + styles=["yellow" for i in excluded.columns], + ) + print(table) + peakipy_data.df = peakipy_data.df[peakipy_data.df.include == "yes"] + return peakipy_data + + +def warn_if_trying_to_fit_large_clusters(max_cluster_size, peakipy_data): + if max_cluster_size is None: + max_cluster_size = peakipy_data.df.MEMCNT.max() + if peakipy_data.df.MEMCNT.max() > 10: + print( + f"""[red] + ################################################################## + You have some clusters of as many as {max_cluster_size} peaks. + You may want to consider reducing the size of your clusters as the + fits will struggle. 
+def warn_if_trying_to_fit_large_clusters(max_cluster_size, peakipy_data):
+    if max_cluster_size is None:
+        max_cluster_size = peakipy_data.df.MEMCNT.max()
+        if peakipy_data.df.MEMCNT.max() > 10:
+            print(
+                f"""[red]
+                ##################################################################
+                You have some clusters of as many as {max_cluster_size} peaks.
+                You may want to consider reducing the size of your clusters as the
+                fits will struggle.
+
+                Otherwise you can use the --max-cluster-size flag to exclude large
+                clusters
+                ##################################################################
+                [/red]"""
+            )
+    return max_cluster_size
diff --git a/test/test_cli.py b/test/test_cli.py
index aec9c35c..d30404c4 100644
--- a/test/test_cli.py
+++ b/test/test_cli.py
@@ -6,12 +6,10 @@
 import peakipy.cli.main
 from peakipy.cli.main import PeaklistFormat, Lineshape
 
-os.chdir("test")
-
 
 @pytest.fixture
 def protein_L():
-    path = Path("test_protein_L")
+    path = Path("test/test_protein_L")
     return path
diff --git a/test/test_data.py b/test/test_data.py
index 2a37150e..0b105bf8 100644
--- a/test/test_data.py
+++ b/test/test_data.py
@@ -6,15 +6,8 @@
 from mpl_toolkits.mplot3d import Axes3D
 from lmfit import Model, report_fit
 
-from peakipy.core import (
-    pvoigt2d,
-    fix_params,
-    get_params,
-    make_mask,
-    # fit_first_plane,
-    make_models,
-    Lineshape,
-)
+from peakipy.lineshapes import pvoigt2d, Lineshape
+from peakipy.fitting import make_mask, make_models
 
 
 def fit_first_plane(
@@ -153,7 +146,7 @@
     for p in peaks:
         data += pvoigt2d(
             XY,
-            *p
+            *p,
             # amplitude=1e8,
             # center_x=200,
             # center_y=100,
diff --git a/test/test_edit.py b/test/test_edit.py
new file mode 100644
index 00000000..ab01e4ab
--- /dev/null
+++ b/test/test_edit.py
@@ -0,0 +1 @@
+import panel as pn
diff --git a/test/test_fit.py b/test/test_fit.py
index 83512aed..c41d0ebf 100644
--- a/test/test_fit.py
+++ b/test/test_fit.py
@@ -25,7 +25,7 @@
     FitPeaksArgs,
     FitPeaksInput,
 )
-from peakipy.core import Lineshape, pvoigt2d
+from peakipy.lineshapes import Lineshape, pvoigt2d
 
 
 def test_get_fit_peaks_result_validation_model_PVPV():
diff --git a/test/test_fitting.py b/test/test_fitting.py
new file mode 100644
index 00000000..01c9ed38
--- /dev/null
+++ b/test/test_fitting.py
@@ -0,0 +1,692 @@
+import unittest
+from pathlib import Path
+from collections import namedtuple
+
+import numpy as np
+import pandas as pd
+import pytest
+import nmrglue as ng
+from numpy.testing import assert_array_equal
+from lmfit import Model, Parameters
+
+from peakipy.io import Pseudo3D
+from peakipy.fitting import (
+    FitDataModel,
+    validate_fit_data,
+    validate_fit_dataframe,
+    select_reference_planes_using_indices,
+    slice_peaks_from_data_using_mask,
+    select_planes_above_threshold_from_masked_data,
+    get_limits_for_axis_in_points,
+    deal_with_peaks_on_edge_of_spectrum,
+    estimate_amplitude,
+    make_mask,
+    make_mask_from_peak_cluster,
+    make_meshgrid,
+    get_params,
+    fix_params,
+    make_param_dict,
+    to_prefix,
+    make_models,
+    PeakLimits,
+    update_params,
+    make_masks_from_plane_data,
+)
+from peakipy.lineshapes import Lineshape, pvoigt2d, pv_pv
+
+
+@pytest.fixture
+def fitdatamodel_dict():
+    return FitDataModel(
+        plane=1,
+        clustid=1,
+        assignment="assignment",
+        memcnt=1,
+        amp=10.0,
+        height=10.0,
+        center_x_ppm=0.0,
+        center_y_ppm=0.0,
+        fwhm_x_hz=10.0,
+        fwhm_y_hz=10.0,
+        lineshape="PV",
+        x_radius=0.04,
+        y_radius=0.4,
+        center_x=0.0,
+        center_y=0.0,
+        sigma_x=1.0,
+        sigma_y=1.0,
+    ).model_dump()
+
+
+def test_validate_fit_data_PVGL(fitdatamodel_dict):
+    fitdatamodel_dict.update(dict(fraction=0.5))
+    validate_fit_data(fitdatamodel_dict)
+
+    fitdatamodel_dict.update(dict(lineshape="G"))
+    validate_fit_data(fitdatamodel_dict)
+
+    fitdatamodel_dict.update(dict(lineshape="L"))
+    validate_fit_data(fitdatamodel_dict)
+
+    fitdatamodel_dict.update(
+        dict(lineshape="V", fraction=0.5, gamma_x=1.0, gamma_y=1.0)
+    )
+    validate_fit_data(fitdatamodel_dict)
+
+    fitdatamodel_dict.update(dict(lineshape="PVPV", fraction_x=0.5, fraction_y=1.0))
+    validate_fit_data(fitdatamodel_dict)
+
+
+def test_validate_fit_dataframe(fitdatamodel_dict):
+    fitdatamodel_dict.update(dict(fraction=0.5))
+    df = pd.DataFrame([fitdatamodel_dict] * 5)
+    validate_fit_dataframe(df)
+
+
+def test_select_reference_planes_using_indices():
+    data = np.zeros((6, 100, 200))
+    indices = []
+    np.testing.assert_array_equal(
+        select_reference_planes_using_indices(data, indices), data
+    )
+    indices = [1]
+    assert select_reference_planes_using_indices(data, indices).shape == (1, 100, 200)
+    indices = [1, -1]
+    assert select_reference_planes_using_indices(data, indices).shape == (2, 100, 200)
+
+
+def test_select_reference_planes_using_indices_min_index_error():
+    data = np.zeros((6, 100, 200))
+    indices = [-7]
+    with pytest.raises(IndexError):
+        select_reference_planes_using_indices(data, indices)
+
+
+def test_select_reference_planes_using_indices_max_index_error():
+    data = np.zeros((6, 100, 200))
+    indices = [6]
+    with pytest.raises(IndexError):
+        select_reference_planes_using_indices(data, indices)
+
+
+def test_slice_peaks_from_data_using_mask():
+    data = np.array(
+        [
+            np.array(
+                [
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+                    [0, 0, 0, 1, 2, 2, 1, 0, 0, 0],
+                    [0, 0, 1, 2, 3, 3, 2, 1, 0, 0],
+                    [0, 1, 2, 3, 4, 4, 3, 2, 1, 0],
+                    [1, 2, 3, 4, 5, 5, 4, 3, 2, 1],
+                    [0, 1, 2, 3, 4, 4, 3, 2, 1, 0],
+                    [0, 0, 1, 2, 3, 3, 2, 1, 0, 0],
+                    [0, 0, 0, 1, 2, 2, 1, 0, 0, 0],
+                    [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                ]
+            )
+            for i in range(5)
+        ]
+    )
+    mask = data[0] > 0
+    assert data.shape == (5, 11, 10)
+    assert mask.shape == (11, 10)
+    peak_slices = slice_peaks_from_data_using_mask(data, mask)
+    # array is flattened by application of mask
+    assert peak_slices.shape == (5, 50)
+
+
+def test_select_planes_above_threshold_from_masked_data():
+    peak_slices = np.array(
+        [
+            [1, 1, 1, 1, 1, 1],
+            [2, 2, 2, 2, 2, 2],
+            [-1, -1, -1, -1, -1, -1],
+            [-2, -2, -2, -2, -2, -2],
+        ]
+    )
+    assert peak_slices.shape == (4, 6)
+    threshold = -1
+    assert select_planes_above_threshold_from_masked_data(
+        peak_slices, threshold
+    ).shape == (
+        4,
+        6,
+    )
+    threshold = 2
+    assert_array_equal(
+        select_planes_above_threshold_from_masked_data(peak_slices, threshold),
+        peak_slices,
+    )
+    threshold = 1
+    assert select_planes_above_threshold_from_masked_data(
+        peak_slices, threshold
+    ).shape == (2, 6)
+
+    threshold = None
+    assert_array_equal(
+        select_planes_above_threshold_from_masked_data(peak_slices, threshold),
+        peak_slices,
+    )
+    threshold = 10
+    assert_array_equal(
+        select_planes_above_threshold_from_masked_data(peak_slices, threshold),
+        peak_slices,
+    )
+
+
+def test_make_param_dict():
+    selected_planes = [1, 2]
+    data = np.ones((4, 10, 5))
+    expected_shape = (2, 10, 5)
+    actual_shape = data[np.array(selected_planes)].shape
+    assert expected_shape == actual_shape
+
+
+def test_make_param_dict_sum():
+    data = np.ones((4, 10, 5))
+    expected_sum = 200
+    actual_sum = data.sum()
+    assert expected_sum == actual_sum
+
+
+def test_make_param_dict_selected():
+    selected_planes = [1, 2]
+    data = np.ones((4, 10, 5))
+    data = data[np.array(selected_planes)]
+    expected_sum = 100
+    actual_sum = data.sum()
+    assert expected_sum == actual_sum
+
+
+def test_update_params_normal_case():
+    params = Parameters()
+    params.add("center_x", value=0)
+    params.add("center_y", value=0)
+    params.add("sigma", value=1)
+    params.add("gamma", value=1)
+    params.add("fraction", value=0.5)
+
+    param_dict = {
+        "center_x": 10,
+        "center_y": 20,
+        "sigma": 2,
+        "gamma": 3,
+        "fraction": 0.8,
+    }
+
+    xy_bounds = (5, 5)
+
+    update_params(params, param_dict, Lineshape.PV, xy_bounds)
+
+    assert params["center_x"].value == 10
+    assert params["center_y"].value == 20
+    assert params["sigma"].value == 2
+    assert params["gamma"].value == 3
+    assert params["fraction"].value == 0.8
+    assert params["center_x"].min == 5
+    assert params["center_x"].max == 15
+    assert params["center_y"].min == 15
+    assert params["center_y"].max == 25
+    assert params["sigma"].min == 0.0
+    assert params["sigma"].max == 1e4
+    assert params["gamma"].min == 0.0
+    assert params["gamma"].max == 1e4
+    assert params["fraction"].min == 0.0
+    assert params["fraction"].max == 1.0
+    assert params["fraction"].vary is True
+
+
+def test_update_params_lineshape_G():
+    params = Parameters()
+    params.add("fraction", value=0.5)
+
+    param_dict = {"fraction": 0.7}
+
+    update_params(params, param_dict, Lineshape.G)
+
+    assert params["fraction"].value == 0.7
+    assert params["fraction"].min == 0.0
+    assert params["fraction"].max == 1.0
+    assert params["fraction"].vary is False
+
+
+def test_update_params_lineshape_L():
+    params = Parameters()
+    params.add("fraction", value=0.5)
+
+    param_dict = {"fraction": 0.7}
+
+    update_params(params, param_dict, Lineshape.L)
+
+    assert params["fraction"].value == 0.7
+    assert params["fraction"].min == 0.0
+    assert params["fraction"].max == 1.0
+    assert params["fraction"].vary is False
+
+
+def test_update_params_lineshape_PV_PV():
+    params = Parameters()
+    params.add("fraction", value=0.5)
+
+    param_dict = {"fraction": 0.7}
+
+    update_params(params, param_dict, Lineshape.PV_PV)
+
+    assert params["fraction"].value == 0.7
+    assert params["fraction"].min == 0.0
+    assert params["fraction"].max == 1.0
+    assert params["fraction"].vary is True
+
+
+def test_update_params_no_bounds():
+    params = Parameters()
+    params.add("center_x", value=0)
+    params.add("center_y", value=0)
+
+    param_dict = {
+        "center_x": 10,
+        "center_y": 20,
+    }
+
+    update_params(params, param_dict, Lineshape.PV, None)
+
+    assert params["center_x"].value == 10
+    assert params["center_y"].value == 20
+    assert params["center_x"].min == -np.inf
+    assert params["center_x"].max == np.inf
+    assert params["center_y"].min == -np.inf
+    assert params["center_y"].max == np.inf
+
+
+def test_peak_limits_normal_case():
+    peak = pd.DataFrame({"X_AXIS": [5], "Y_AXIS": [5], "XW": [2], "YW": [2]}).iloc[0]
+    data = np.zeros((10, 10))
+    pl = PeakLimits(peak, data)
+    assert pl.min_x == 3
+    assert pl.max_x == 8
+    assert pl.min_y == 3
+    assert pl.max_y == 8
+
+
+def test_peak_limits_at_edge():
+    peak = pd.DataFrame({"X_AXIS": [1], "Y_AXIS": [1], "XW": [2], "YW": [2]}).iloc[0]
+    data = np.zeros((10, 10))
+    pl = PeakLimits(peak, data)
+    assert pl.min_x == 0
+    assert pl.max_x == 4
+    assert pl.min_y == 0
+    assert pl.max_y == 4
+
+
+def test_peak_limits_exceeding_bounds():
+    peak = pd.DataFrame({"X_AXIS": [9], "Y_AXIS": [9], "XW": [2], "YW": [2]}).iloc[0]
+    data = np.zeros((10, 10))
+    pl = PeakLimits(peak, data)
+    assert pl.min_x == 7
+    assert pl.max_x == 10
+    assert pl.min_y == 7
+    assert pl.max_y == 10
+
+
+def test_peak_limits_small_data():
+    peak = pd.DataFrame({"X_AXIS": [2], "Y_AXIS": [2], "XW": [5], "YW": [5]}).iloc[0]
+    data = np.zeros((5, 5))
+    pl = PeakLimits(peak, data)
+    assert pl.min_x == 0
+    assert pl.max_x == 5
+    assert pl.min_y == 0
+    assert pl.max_y == 5
+
+
[11], "Y_AXIS": [11], "XW": [2], "YW": [2]}).iloc[0] + data = np.zeros((10, 10)) + with pytest.raises(AssertionError): + pl = PeakLimits(peak, data) + + +def test_estimate_amplitude(): + peak = namedtuple("peak", ["X_AXIS", "XW", "Y_AXIS", "YW"]) + p = peak(5, 2, 3, 2) + data = np.ones((20, 10)) + expected_result = 25 + actual_result = estimate_amplitude(p, data) + assert expected_result == actual_result + + +def test_estimate_amplitude_invalid_indices(): + peak = namedtuple("peak", ["X_AXIS", "XW", "Y_AXIS", "YW"]) + p = peak(1, 2, 3, 2) + data = np.ones((20, 10)) + expected_result = 20 + actual_result = estimate_amplitude(p, data) + assert expected_result == actual_result + + +def test_make_mask_from_peak_cluster(): + data = np.ones((10, 10)) + group = pd.DataFrame( + {"X_AXISf": [3, 6], "Y_AXISf": [3, 6], "X_RADIUS": [2, 3], "Y_RADIUS": [2, 3]} + ) + mask, peak = make_mask_from_peak_cluster(group, data) + expected_mask = np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + ], + dtype=bool, + ) + assert_array_equal(expected_mask, mask) + + +# get_limits_for_axis_in_points +def test_positive_points(): + group_axis_points = np.array([1, 2, 3, 4, 5]) + mask_radius_in_points = 2 + expected = (8, -1) # ceil(5+1+1), floor(1-1) + assert ( + get_limits_for_axis_in_points(group_axis_points, mask_radius_in_points) + == expected + ) + + +def test_single_point(): + group_axis_points = np.array([5]) + mask_radius_in_points = 3 + expected = (9, 2) + assert ( + get_limits_for_axis_in_points(group_axis_points, mask_radius_in_points) + == expected + ) + + +def test_no_radius(): + group_axis_points = np.array([1, 2, 3]) + mask_radius_in_points = 0 + expected = (4, 1) + assert ( + get_limits_for_axis_in_points(group_axis_points, mask_radius_in_points) + == expected + ) + + +# deal_with_peaks_on_edge_of_spectrum +def test_min_y_less_than_zero(): + assert deal_with_peaks_on_edge_of_spectrum((100, 200), 50, 30, 10, -10) == ( + 50, + 30, + 10, + 0, + ) + + +def test_min_x_less_than_zero(): + assert deal_with_peaks_on_edge_of_spectrum((100, 200), 50, -5, 70, 20) == ( + 50, + 0, + 70, + 20, + ) + + +def test_max_y_exceeds_data_shape(): + assert deal_with_peaks_on_edge_of_spectrum((100, 200), 50, 30, 110, 20) == ( + 50, + 30, + 100, + 20, + ) + + +def test_max_x_exceeds_data_shape(): + assert deal_with_peaks_on_edge_of_spectrum((100, 200), 250, 30, 70, 20) == ( + 200, + 30, + 70, + 20, + ) + + +def test_values_within_range(): + assert deal_with_peaks_on_edge_of_spectrum((100, 200), 50, 30, 70, 20) == ( + 50, + 30, + 70, + 20, + ) + + +def test_all_edge_cases(): + assert deal_with_peaks_on_edge_of_spectrum((100, 200), 250, -5, 110, -10) == ( + 200, + 0, + 100, + 0, + ) + + +def test_make_meshgrid(): + data_shape = (4, 5) + expected_x = np.array( + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] + ) + expected_y = np.array( + [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]] + ) + XY = make_meshgrid(data_shape) + np.testing.assert_array_equal(XY[0], expected_x) + np.testing.assert_array_equal(XY[1], expected_y) + + +class TestCoreFunctions(unittest.TestCase): + test_directory = Path(__file__).parent + test_directory = "./test" + + def test_make_mask(self): + data = np.ones((10, 
+    def test_make_mask(self):
+        data = np.ones((10, 10))
+        c_x = 5
+        c_y = 5
+        r_x = 3
+        r_y = 2
+
+        expected_result = np.array(
+            [
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            ]
+        )
+
+        result = np.array(make_mask(data, c_x, c_y, r_x, r_y), dtype=int)
+        test = result - expected_result
+        # print(test)
+        # print(test.sum())
+        # print(result)
+        self.assertEqual(test.sum(), 0)
+
+    def test_make_mask_2(self):
+        data = np.ones((10, 10))
+        c_x = 5
+        c_y = 8
+        r_x = 3
+        r_y = 2
+
+        expected_result = np.array(
+            [
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
+            ]
+        )
+
+        result = np.array(make_mask(data, c_x, c_y, r_x, r_y), dtype=int)
+        test = result - expected_result
+        # print(test)
+        # print(test.sum())
+        # print(result)
+        self.assertEqual(test.sum(), 0)
+
+    def test_fix_params(self):
+        mod = Model(pvoigt2d)
+        pars = mod.make_params()
+        to_fix = ["center", "sigma", "fraction"]
+        fix_params(pars, to_fix)
+
+        self.assertEqual(pars["center_x"].vary, False)
+        self.assertEqual(pars["center_y"].vary, False)
+        self.assertEqual(pars["sigma_x"].vary, False)
+        self.assertEqual(pars["sigma_y"].vary, False)
+        self.assertEqual(pars["fraction"].vary, False)
+
+    def test_get_params(self):
+        mod = Model(pvoigt2d, prefix="p1_")
+        pars = mod.make_params(p1_center_x=20.0, p1_center_y=30.0)
+        pars["p1_center_x"].stderr = 1.0
+        pars["p1_center_y"].stderr = 2.0
+        ps, ps_err, names, prefixes = get_params(pars, "center")
+        # get index of values
+        cen_x = names.index("p1_center_x")
+        cen_y = names.index("p1_center_y")
+
+        self.assertEqual(ps[cen_x], 20.0)
+        self.assertEqual(ps[cen_y], 30.0)
+        self.assertEqual(ps_err[cen_x], 1.0)
+        self.assertEqual(ps_err[cen_y], 2.0)
+        self.assertEqual(prefixes[cen_y], "p1_")
+
+        self.assertEqual(
+            voigt_params["_one_sigma_x"], 2.5 / (2.0 * np.sqrt(2.0 * np.log(2)))
+        )
+        self.assertEqual(voigt_params["_one_gamma_x"], 2.5 / 2.0)
+
+    def test_to_prefix(self):
+        names = [
+            (1, "_1_"),
+            (1.0, "_1_0_"),
+            (" one", "_one_"),
+            (" one/two", "_oneortwo_"),
+            (" one?two", "_onemaybetwo_"),
+            (r" [{one?two\}][", "___onemaybetwo____"),
+        ]
+        for test, expect in names:
+            prefix = to_prefix(test)
+            self.assertEqual(prefix, expect)
+
+    def test_make_models(self):
+        peaks = pd.DataFrame(
+            {
+                "ASS": ["one", "two", "three"],
+                "X_AXISf": [5.0, 10.0, 15.0],
+                "X_AXIS": [5, 10, 15],
+                "Y_AXISf": [15.0, 10.0, 5.0],
+                "Y_AXIS": [15, 10, 5],
+                "XW": [2.5, 2.5, 2.5],
+                "YW": [2.5, 2.5, 2.5],
+                "CLUSTID": [1, 1, 1],
+            }
+        )
+
+        group = peaks.groupby("CLUSTID")
+
+        data = np.ones((20, 20))
+
+        lineshapes = [Lineshape.PV, Lineshape.L, Lineshape.G, Lineshape.PV_PV]
+
+        for lineshape in lineshapes:
+            match lineshape:
+                case Lineshape.PV:
+                    mod, p_guess = make_models(pvoigt2d, peaks, data, lineshape)
+                    self.assertEqual(p_guess["_one_fraction"].vary, True)
+                    self.assertEqual(p_guess["_one_fraction"].value, 0.5)
+
+                case Lineshape.G:
+                    mod, p_guess = make_models(pvoigt2d, peaks, data, lineshape)
+                    self.assertEqual(p_guess["_one_fraction"].vary, False)
+                    self.assertEqual(p_guess["_one_fraction"].value, 0.0)
+
+                case Lineshape.L:
+                    mod, p_guess = make_models(pvoigt2d, peaks, data, lineshape)
+                    self.assertEqual(p_guess["_one_fraction"].vary, False)
+                    self.assertEqual(p_guess["_one_fraction"].value, 1.0)
+
+                case Lineshape.PV_PV:
+                    mod, p_guess = make_models(pv_pv, peaks, data, lineshape)
+                    self.assertEqual(p_guess["_one_fraction_x"].vary, True)
+                    self.assertEqual(p_guess["_one_fraction_x"].value, 0.5)
+                    self.assertEqual(p_guess["_one_fraction_y"].vary, True)
+                    self.assertEqual(p_guess["_one_fraction_y"].value, 0.5)
+
+    def test_Pseudo3D(self):
+        datasets = [
+            (f"{self.test_directory}/test_protein_L/test1.ft2", [0, 1, 2]),
+            (f"{self.test_directory}/test_protein_L/test_tp.ft2", [2, 1, 0]),
+            (f"{self.test_directory}/test_protein_L/test_tp2.ft2", [1, 2, 0]),
+        ]
+
+        # expected shape
+        data_shape = (4, 256, 546)
+        test_nu = 1
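+        # all three datasets hold the same spectrum with the axes stored in a
+        # different order; Pseudo3D should use dims to restore (planes, F1, F2)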
+        for dataset, dims in datasets:
+            with self.subTest(i=test_nu):
+                dic, data = ng.pipe.read(dataset)
+                pseudo3D = Pseudo3D(dic, data, dims)
+                self.assertEqual(dims, pseudo3D.dims)
+                self.assertEqual(pseudo3D.data.shape, data_shape)
+                self.assertEqual(pseudo3D.f1_label, "15N")
+                self.assertEqual(pseudo3D.f2_label, "HN")
+                self.assertEqual(pseudo3D.dims, dims)
+                self.assertEqual(pseudo3D.f1_size, 256)
+                self.assertEqual(pseudo3D.f2_size, 546)
+            test_nu += 1
diff --git a/test/test_io.py b/test/test_io.py
new file mode 100644
index 00000000..a945852d
--- /dev/null
+++ b/test/test_io.py
@@ -0,0 +1,428 @@
+import unittest
+from unittest.mock import patch
+from pathlib import Path
+import json
+
+import pytest
+import numpy as np
+import nmrglue as ng
+import pandas as pd
+
+from peakipy.io import (
+    Pseudo3D,
+    Peaklist,
+    LoadData,
+    PeaklistFormat,
+    OutFmt,
+    StrucEl,
+    UnknownFormat,
+    ClustersResult,
+    get_vclist,
+)
+from peakipy.fitting import PeakLimits
+from peakipy.utils import load_config, write_config, update_config_file
+
+
+@pytest.fixture
+def test_directory():
+    return Path(__file__).parent
+
+
+# Smoke tests for the read, edit, fit, check and spec entry points: the real
+# functions are patched out, so these only verify call wiring.
+# TODO: replace with proper functional tests.
+class TestBokehScript(unittest.TestCase):
+    @patch("peakipy.cli.edit.BokehScript")
+    def test_BokehScript(self, MockBokehScript):
+        args = {"<peaklist>": "hello", "<data>": "data"}
+        bokeh_plots = MockBokehScript(args)
+        self.assertIsNotNone(bokeh_plots)
+
+
+class TestCheckScript(unittest.TestCase):
+    @patch("peakipy.cli.main.check")
+    def test_main(self, MockCheck):
+        args = {"<peaklist>": "hello", "<data>": "data"}
+        check = MockCheck(args)
+        self.assertIsNotNone(check)
+
+
+class TestFitScript(unittest.TestCase):
+    @patch("peakipy.cli.main.fit")
+    def test_main(self, MockFit):
+        args = {"<peaklist>": "hello", "<data>": "data"}
+        fit = MockFit(args)
+        self.assertIsNotNone(fit)
+
+
+class TestReadScript(unittest.TestCase):
+    test_directory = "./test"
+
+    @patch("peakipy.cli.main.read")
+    def test_main(self, MockRead):
+        args = {"<peaklist>": "hello", "<data>": "data"}
+        read = MockRead(args)
+        self.assertIsNotNone(read)
+
+    def test_read_pipe_peaklist(self):
+        args = {
+            "path": f"{self.test_directory}/test_pipe.tab",
+            "data_path": f"{self.test_directory}/test_pipe.ft2",
+            "dims": [0, 1, 2],
+            "fmt": PeaklistFormat.pipe,
+        }
+        peaklist = Peaklist(**args)
+        self.assertIsNotNone(peaklist)
+        self.assertEqual(len(peaklist.df), 3)
+        # self.assertIs(peaklist.df.X_AXISf.iloc[0], 323.019)
+        self.assertEqual(peaklist.fmt.value, "pipe")
+        # self.assertEqual(peaklist.df.ASS.iloc[0], "None")
+        # self.assertEqual(peaklist.df.ASS.iloc[1], "None_dummy_1")
+
+
+class TestSpecScript(unittest.TestCase):
+    @patch("peakipy.cli.main.spec")
+    def test_main(self, MockSpec):
+        args = {"<peaklist>": "hello", "<data>": "data"}
+        spec = MockSpec(args)
+        self.assertIsNotNone(spec)
+
+
+def test_load_config_existing():
+    config_path = Path("test_config.json")
+    # Create a dummy existing config file
+    with open(config_path, "w") as f:
+        json.dump({"key1": "value1"}, f)
+
+    loaded_config = load_config(config_path)
+
+    assert loaded_config == {"key1": "value1"}
+
+    # Clean up
+    config_path.unlink()
+
+
+def test_load_config_nonexistent():
+    config_path = Path("test_config.json")
+
+    loaded_config = load_config(config_path)
+
+    assert loaded_config == {}
+
+
+def test_write_config():
+    config_path = Path("test_config.json")
+    config_kvs = {"key1": "value1", "key2": "value2"}
+
+    write_config(config_path, config_kvs)
+
+    # Check if the config file is created correctly
+    assert config_path.exists()
+
+    # Check if the config file content is correct
+    with open(config_path) as f:
+        created_config = json.load(f)
+        assert created_config == {"key1": "value1", "key2": "value2"}
+
+    # Clean up
+    config_path.unlink()
+
+
+def test_update_config_file_existing():
+    config_path = Path("test_config.json")
+    # Create a dummy existing config file
+    with open(config_path, "w") as f:
+        json.dump({"key1": "value1"}, f)
+
+    config_kvs = {"key2": "value2", "key3": "value3"}
+    updated_config = update_config_file(config_path, config_kvs)
+
+    assert updated_config == {"key1": "value1", "key2": "value2", "key3": "value3"}
+
+    # Clean up
+    config_path.unlink()
+
+
+def test_update_config_file_nonexistent():
+    config_path = Path("test_config.json")
+    config_kvs = {"key1": "value1", "key2": "value2"}
+    updated_config = update_config_file(config_path, config_kvs)
+
+    assert updated_config == {"key1": "value1", "key2": "value2"}
+
+    # Clean up
+    config_path.unlink()
+
+
+@pytest.fixture
+def sample_data():
+    return np.zeros((10, 10))
+
+
+@pytest.fixture
+def sample_peak():
+    peak_data = {"X_AXIS": [5], "Y_AXIS": [5], "XW": [2], "YW": [2]}
+    return pd.DataFrame(peak_data).iloc[0]
+
+
+def test_peak_limits_max_min(sample_peak, sample_data):
+    limits = PeakLimits(sample_peak, sample_data)
+
+    assert limits.max_x == 8
+    assert limits.max_y == 8
+    assert limits.min_x == 3
+    assert limits.min_y == 3
+
+
+def 
test_peak_limits_boundary(sample_data): + peak_data = {"X_AXIS": [8], "Y_AXIS": [8], "XW": [2], "YW": [2]} + peak = pd.DataFrame(peak_data).iloc[0] + limits = PeakLimits(peak, sample_data) + + assert limits.max_x == 10 + assert limits.max_y == 10 + assert limits.min_x == 6 + assert limits.min_y == 6 + + +def test_peak_limits_at_boundary(sample_data): + peak_data = {"X_AXIS": [0], "Y_AXIS": [0], "XW": [2], "YW": [2]} + peak = pd.DataFrame(peak_data).iloc[0] + limits = PeakLimits(peak, sample_data) + + assert limits.max_x == 3 + assert limits.max_y == 3 + assert limits.min_x == 0 + assert limits.min_y == 0 + + +def test_peak_limits_outside_boundary(sample_data): + peak_data = {"X_AXIS": [15], "Y_AXIS": [15], "XW": [2], "YW": [2]} + peak = pd.DataFrame(peak_data).iloc[0] + with pytest.raises(AssertionError): + limits = PeakLimits(peak, sample_data) + + +def test_peak_limits_1d_data(): + data = np.zeros(10) + peak_data = {"X_AXIS": [5], "Y_AXIS": [0], "XW": [2], "YW": [0]} + peak = pd.DataFrame(peak_data).iloc[0] + with pytest.raises(IndexError): + limits = PeakLimits(peak, data) + + +def test_StrucEl(): + assert StrucEl.square.value == "square" + assert StrucEl.disk.value == "disk" + assert StrucEl.rectangle.value == "rectangle" + assert StrucEl.mask_method.value == "mask_method" + + +def test_PeaklistFormat(): + assert PeaklistFormat.a2.value == "a2" + assert PeaklistFormat.a3.value == "a3" + assert PeaklistFormat.sparky.value == "sparky" + assert PeaklistFormat.pipe.value == "pipe" + assert PeaklistFormat.peakipy.value == "peakipy" + + +def test_OutFmt(): + assert OutFmt.csv.value == "csv" + assert OutFmt.pkl.value == "pkl" + + +@pytest.fixture +def test_data_path(): + return Path("./test/test_protein_L") + + +@pytest.fixture +def pseudo3d_args(test_data_path): + dic, data = ng.pipe.read(test_data_path / "test1.ft2") + dims = [0, 1, 2] + return dic, data, dims + + +@pytest.fixture +def peaklist(test_data_path): + dims = [0, 1, 2] + path = test_data_path / "test.tab" + data_path = test_data_path / "test1.ft2" + fmt = PeaklistFormat.pipe + radii = [0.04, 0.4] + return Peaklist(path, data_path, fmt, dims, radii) + + +def test_Pseudo3D_properties(pseudo3d_args): + dic, data, dims = pseudo3d_args + pseudo3d = Pseudo3D(dic, data, dims) + assert pseudo3d.dic == dic + assert np.array_equal(pseudo3d._data, data.reshape((4, 256, 546))) + assert pseudo3d.dims == dims + + +def test_Peaklist_initialization(test_data_path, peaklist): + + assert peaklist.peaklist_path == test_data_path / "test.tab" + assert peaklist.data_path == test_data_path / "test1.ft2" + assert peaklist.fmt == PeaklistFormat.pipe + assert peaklist.radii == [0.04, 0.4] + + +def test_Peaklist_a2(test_data_path): + dims = [0, 1, 2] + path = test_data_path / "peaks.a2" + data_path = test_data_path / "test1.ft2" + fmt = PeaklistFormat.a2 + radii = [0.04, 0.4] + peaklist = Peaklist(path, data_path, fmt, dims, radii) + peaklist.update_df() + + +def test_Peaklist_a3(test_data_path): + dims = [0, 1, 2] + path = test_data_path / "ccpnTable.tsv" + data_path = test_data_path / "test1.ft2" + fmt = PeaklistFormat.a3 + radii = [0.04, 0.4] + peaklist = Peaklist(path, data_path, fmt, dims, radii) + peaklist.update_df() + + +def test_Peaklist_sparky(test_data_path): + dims = [0, 1, 2] + path = test_data_path / "peaks.sparky" + data_path = test_data_path / "test1.ft2" + fmt = PeaklistFormat.sparky + radii = [0.04, 0.4] + Peaklist(path, data_path, fmt, dims, radii) + + +@pytest.fixture +def loaddata(test_data_path): + dims = [0, 1, 2] + path = 
test_data_path / "test.csv" + data_path = test_data_path / "test1.ft2" + fmt = PeaklistFormat.peakipy + radii = [0.04, 0.4] + return LoadData(path, data_path, fmt, dims, radii) + + +def test_LoadData_initialization(test_data_path, loaddata): + assert loaddata.peaklist_path == test_data_path / "test.csv" + assert loaddata.data_path == test_data_path / "test1.ft2" + assert loaddata.fmt == PeaklistFormat.peakipy + assert loaddata.radii == [0.04, 0.4] + loaddata.check_data_frame() + loaddata.check_assignments() + loaddata.check_peak_bounds() + loaddata.update_df() + + +def test_LoadData_with_Edited_column(loaddata): + loaddata.df["Edited"] = "yes" + loaddata.check_data_frame() + + +def test_LoadData_without_include_column(loaddata): + loaddata.df.drop(columns=["include"], inplace=True) + loaddata.check_data_frame() + assert "include" in loaddata.df.columns + assert np.all(loaddata.df.include == "yes") + + +def test_LoadData_with_X_DIAMETER_PPM_column(loaddata): + loaddata.df["X_DIAMETER_PPM"] = 0.04 + loaddata.check_data_frame() + assert "X_DIAMETER_PPM" in loaddata.df.columns + + +def test_UnknownFormat(): + with pytest.raises(UnknownFormat): + raise UnknownFormat("This is an unknown format") + + +def test_update_df(peaklist): + peaklist.update_df() + + df = peaklist.df + + # Check that X_AXIS and Y_AXIS columns are created and values are set correctly + assert "X_AXIS" in df.columns + assert "Y_AXIS" in df.columns + + # Check that X_AXISf and Y_AXISf columns are created and values are set correctly + assert "X_AXISf" in df.columns + assert "Y_AXISf" in df.columns + + # Check that XW_HZ and YW_HZ columns are converted to float correctly + assert df["XW_HZ"].dtype == float + assert df["YW_HZ"].dtype == float + + # Check that XW and YW columns are created + assert "XW" in df.columns + assert "YW" in df.columns + + # Check the assignment column + assert "ASS" in df.columns + + # Check radii columns + assert "X_RADIUS_PPM" in df.columns + assert "Y_RADIUS_PPM" in df.columns + assert "X_RADIUS" in df.columns + assert "Y_RADIUS" in df.columns + + # Check 'include' column is created and set to 'yes' + assert "include" in df.columns + assert all(df["include"] == "yes") + + # Check that the peaks are within bounds + assert all( + (df["X_PPM"] < peaklist.f2_ppm_max) & (df["X_PPM"] > peaklist.f2_ppm_min) + ) + assert all( + (df["Y_PPM"] < peaklist.f1_ppm_max) & (df["Y_PPM"] > peaklist.f1_ppm_min) + ) + + +def test_update_df_with_excluded_peaks(peaklist): + peaklist._df.loc[1, "X_PPM"] = 100.0 # This peak should be out of bounds + peaklist.update_df() + + df = peaklist.df + + # Check that out of bounds peak is excluded + assert len(df) == 62 + assert not ((df["X_PPM"] == 100.0).any()) + + +def test_clusters_result_initialization(): + labeled_array = np.array([[1, 2], [3, 4]]) + num_features = 5 + closed_data = np.array([[5, 6], [7, 8]]) + peaks = [(1, 2), (3, 4)] + + clusters_result = ClustersResult(labeled_array, num_features, closed_data, peaks) + + assert np.array_equal(clusters_result.labeled_array, labeled_array) + assert clusters_result.num_features == num_features + assert np.array_equal(clusters_result.closed_data, closed_data) + assert clusters_result.peaks == peaks + + +def test_get_vclist_None(): + assert get_vclist(None, {})["vclist"] == False + + +def test_get_vclist_exists(test_data_path): + vclist = test_data_path / "vclist" + assert get_vclist(vclist, {})["vclist"] == True + + +def test_get_vclist_not_exists(test_data_path): + vclist = test_data_path / "vclistbla" + with 
pytest.raises(Exception):
+        get_vclist(vclist, {})
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/test/test_lineshapes.py b/test/test_lineshapes.py
new file mode 100644
index 00000000..43e867c9
--- /dev/null
+++ b/test/test_lineshapes.py
@@ -0,0 +1,353 @@
+from pathlib import Path
+from unittest.mock import Mock
+
+import pytest
+import pandas as pd
+import numpy as np
+from numpy.testing import assert_almost_equal
+
+from peakipy.io import Peaklist, PeaklistFormat
+from peakipy.constants import tiny
+from peakipy.lineshapes import (
+    gaussian,
+    gaussian_lorentzian,
+    pv_g,
+    pv_l,
+    voigt2d,
+    pvoigt2d,
+    pv_pv,
+    get_lineshape_function,
+    Lineshape,
+    calculate_height_for_voigt_lineshape,
+    calculate_fwhm_for_voigt_lineshape,
+    calculate_fwhm_for_pseudo_voigt_lineshape,
+    calculate_height_for_pseudo_voigt_lineshape,
+    calculate_height_for_gaussian_lineshape,
+    calculate_height_for_lorentzian_lineshape,
+    calculate_height_for_pv_pv_lineshape,
+    calculate_lineshape_specific_height_and_fwhm,
+    calculate_peak_linewidths_in_hz,
+    calculate_peak_centers_in_ppm,
+)
+
+
+# each of these checks gaussian() against the normalised Gaussian
+# 1 / (sqrt(2 * pi) * sigma) * exp(-(x - center)**2 / (2 * sigma**2))
+def test_gaussian_typical_values():
+    x = np.array([0, 1, 2])
+    center = 0.0
+    sigma = 1.0
+    expected = (1.0 / (np.sqrt(2 * np.pi) * sigma)) * np.exp(
+        -((x - center) ** 2) / (2 * sigma**2)
+    )
+    result = gaussian(x, center, sigma)
+    assert_almost_equal(result, expected, decimal=7)
+
+
+def test_gaussian_center_nonzero():
+    x = np.array([0, 1, 2])
+    center = 1.0
+    sigma = 1.0
+    expected = (1.0 / (np.sqrt(2 * np.pi) * sigma)) * np.exp(
+        -((x - center) ** 2) / (2 * sigma**2)
+    )
+    result = gaussian(x, center, sigma)
+    assert_almost_equal(result, expected, decimal=7)
+
+
+def test_gaussian_sigma_nonzero():
+    x = np.array([0, 1, 2])
+    center = 0.0
+    sigma = 2.0
+    expected = (1.0 / (np.sqrt(2 * np.pi) * sigma)) * np.exp(
+        -((x - center) ** 2) / (2 * sigma**2)
+    )
+    result = gaussian(x, center, sigma)
+    assert_almost_equal(result, expected, decimal=7)
+
+
+def test_gaussian_zero_center():
+    x = np.array([0, 1, 2])
+    center = 0.0
+    sigma = 1.0
+    expected = (1.0 / (np.sqrt(2 * np.pi) * sigma)) * np.exp(
+        -((x - center) ** 2) / (2 * sigma**2)
+    )
+    result = gaussian(x, center, sigma)
+    assert_almost_equal(result, expected, decimal=7)
+
+
+def test_calculate_height_for_voigt_lineshape():
+    data = {
+        "sigma_x": [1.0, 2.0],
+        "sigma_y": [1.0, 2.0],
+        "gamma_x": [1.0, 2.0],
+        "gamma_y": [1.0, 2.0],
+        "amp": [10.0, 20.0],
+        "amp_err": [1.0, 2.0],
+    }
+    df = pd.DataFrame(data)
+    result_df = calculate_height_for_voigt_lineshape(df)
+
+    assert np.allclose(result_df["height"], [0.435596, 0.217798])
+    assert np.allclose(result_df["height_err"], [0.04356, 0.02178])
+
+
+def test_calculate_fwhm_for_voigt_lineshape():
+    data = {
+        "sigma_x": [1.0, 2.0],
+        "sigma_y": [1.0, 2.0],
+        "gamma_x": [1.0, 2.0],
+        "gamma_y": [1.0, 2.0],
+        "amp": [10.0, 20.0],
+        "amp_err": [1.0, 2.0],
+    }
+    df = pd.DataFrame(data)
+    result_df = calculate_fwhm_for_voigt_lineshape(df)
+
+    assert np.allclose(result_df["fwhm_l_x"], [2.0, 4.0])
+    assert np.allclose(result_df["fwhm_l_y"], [2.0, 4.0])
+    assert np.allclose(result_df["fwhm_g_x"], [2.35482, 4.70964])
+    assert np.allclose(result_df["fwhm_g_y"], [2.35482, 4.70964])
+    assert np.allclose(result_df["fwhm_x"], [3.601309, 7.202619])
+    assert np.allclose(result_df["fwhm_y"], [3.601309, 7.202619])
+
+
+def test_calculate_height_for_pseudo_voigt_lineshape():
+    data = {
+        "sigma_x": [1.0, 2.0],
+        "sigma_y": [1.0, 2.0],
+        "gamma_x": [1.0, 2.0],
+        "gamma_y": [1.0, 2.0],
+        "amp": 
[10.0, 20.0], + "amp_err": [1.0, 2.0], + "fraction": [0.5, 0.5], + } + df = pd.DataFrame(data) + result_df = calculate_height_for_pseudo_voigt_lineshape(df) + + assert np.allclose(result_df["height"], [1.552472, 0.776236]) + assert np.allclose(result_df["height_err"], [0.155247, 0.077624]) + + +def test_calculate_fwhm_for_pseudo_voigt_lineshape(): + data = { + "sigma_x": [1.0, 2.0], + "sigma_y": [1.0, 2.0], + "gamma_x": [1.0, 2.0], + "gamma_y": [1.0, 2.0], + "amp": [10.0, 20.0], + "amp_err": [1.0, 2.0], + "fraction": [0.5, 0.5], + } + df = pd.DataFrame(data) + result_df = calculate_fwhm_for_pseudo_voigt_lineshape(df) + + assert np.allclose(result_df["fwhm_x"], [2.0, 4.0]) + assert np.allclose(result_df["fwhm_y"], [2.0, 4.0]) + + +def test_calculate_height_for_gaussian_lineshape(): + data = { + "sigma_x": [1.0, 2.0], + "sigma_y": [1.0, 2.0], + "gamma_x": [1.0, 2.0], + "gamma_y": [1.0, 2.0], + "amp": [10.0, 20.0], + "amp_err": [1.0, 2.0], + "fraction": [0.5, 0.5], + } + df = pd.DataFrame(data) + result_df = calculate_height_for_gaussian_lineshape(df) + + assert np.allclose(result_df["height"], [2.206356, 1.103178]) + assert np.allclose(result_df["height_err"], [0.220636, 0.110318]) + + +def test_calculate_height_for_lorentzian_lineshape(): + data = { + "sigma_x": [1.0, 2.0], + "sigma_y": [1.0, 2.0], + "gamma_x": [1.0, 2.0], + "gamma_y": [1.0, 2.0], + "amp": [10.0, 20.0], + "amp_err": [1.0, 2.0], + "fraction": [0.5, 0.5], + } + df = pd.DataFrame(data) + result_df = calculate_height_for_lorentzian_lineshape(df) + + assert np.allclose(result_df["height"], [1.013212, 0.506606]) + assert np.allclose(result_df["height_err"], [0.101321, 0.050661]) + + +def test_calculate_height_for_pv_pv_lineshape(): + data = { + "sigma_x": [1.0, 2.0], + "sigma_y": [1.0, 2.0], + "gamma_x": [1.0, 2.0], + "gamma_y": [1.0, 2.0], + "amp": [10.0, 20.0], + "amp_err": [1.0, 2.0], + "fraction_x": [0.5, 0.5], + "fraction_y": [0.5, 0.5], + } + df = pd.DataFrame(data) + result_df = calculate_height_for_pv_pv_lineshape(df) + + assert np.allclose(result_df["height"], [1.552472, 0.776236]) + assert np.allclose(result_df["height_err"], [0.155247, 0.077624]) + + +def test_calculate_height_for_pv_pv_lineshape_fraction_y(): + data = { + "sigma_x": [1.0, 2.0], + "sigma_y": [1.0, 2.0], + "gamma_x": [1.0, 2.0], + "gamma_y": [1.0, 2.0], + "amp": [10.0, 20.0], + "amp_err": [1.0, 2.0], + "fraction_x": [0.5, 0.5], + "fraction_y": [1.0, 1.0], + } + df = pd.DataFrame(data) + result_df = calculate_height_for_pv_pv_lineshape(df) + + assert np.allclose(result_df["height"], [1.254186, 0.627093]) + assert np.allclose(result_df["height_err"], [0.125419, 0.062709]) + + +def test_calculate_lineshape_specific_height_and_fwhm(): + data = { + "sigma_x": [1.0, 2.0], + "sigma_y": [1.0, 2.0], + "gamma_x": [1.0, 2.0], + "gamma_y": [1.0, 2.0], + "amp": [10.0, 20.0], + "amp_err": [1.0, 2.0], + "fraction": [0.5, 0.5], + "fraction_x": [0.5, 0.5], + "fraction_y": [0.5, 0.5], + } + df = pd.DataFrame(data) + calculate_lineshape_specific_height_and_fwhm(Lineshape.G, df) + calculate_lineshape_specific_height_and_fwhm(Lineshape.L, df) + calculate_lineshape_specific_height_and_fwhm(Lineshape.V, df) + calculate_lineshape_specific_height_and_fwhm(Lineshape.PV, df) + calculate_lineshape_specific_height_and_fwhm(Lineshape.PV_PV, df) + calculate_lineshape_specific_height_and_fwhm(Lineshape.PV_G, df) + calculate_lineshape_specific_height_and_fwhm(Lineshape.PV_L, df) + + +def test_get_lineshape_function(): + assert get_lineshape_function(Lineshape.PV) == pvoigt2d + assert 
get_lineshape_function(Lineshape.L) == pvoigt2d + assert get_lineshape_function(Lineshape.G) == pvoigt2d + assert get_lineshape_function(Lineshape.G_L) == gaussian_lorentzian + assert get_lineshape_function(Lineshape.PV_G) == pv_g + assert get_lineshape_function(Lineshape.PV_L) == pv_l + assert get_lineshape_function(Lineshape.PV_PV) == pv_pv + assert get_lineshape_function(Lineshape.V) == voigt2d + + +def test_get_lineshape_function_exception(): + with pytest.raises(Exception): + get_lineshape_function("bla") + + +@pytest.fixture +def peakipy_data(): + test_data_path = Path("./test/test_protein_L/") + return Peaklist( + test_data_path / "test.tab", test_data_path / "test1.ft2", PeaklistFormat.pipe + ) + + +def test_calculate_peak_linewidths_in_hz(): + # Sample data for testing + data = { + "sigma_x": [1.0, 2.0, 3.0], + "sigma_y": [1.5, 2.5, 3.5], + "fwhm_x": [0.5, 1.5, 2.5], + "fwhm_y": [0.7, 1.7, 2.7], + } + df = pd.DataFrame(data) + + # Mock peakipy_data object + peakipy_data = Mock() + peakipy_data.ppm_per_pt_f2 = 0.01 + peakipy_data.ppm_per_pt_f1 = 0.02 + peakipy_data.hz_per_pt_f2 = 10.0 + peakipy_data.hz_per_pt_f1 = 20.0 + + # Expected results + expected_sigma_x_ppm = [0.01, 0.02, 0.03] + expected_sigma_y_ppm = [0.03, 0.05, 0.07] + expected_fwhm_x_ppm = [0.005, 0.015, 0.025] + expected_fwhm_y_ppm = [0.014, 0.034, 0.054] + expected_fwhm_x_hz = [5.0, 15.0, 25.0] + expected_fwhm_y_hz = [14.0, 34.0, 54.0] + + # Run the function + result_df = calculate_peak_linewidths_in_hz(df, peakipy_data) + + # Assertions + pd.testing.assert_series_equal( + result_df["sigma_x_ppm"], pd.Series(expected_sigma_x_ppm), check_names=False + ) + pd.testing.assert_series_equal( + result_df["sigma_y_ppm"], pd.Series(expected_sigma_y_ppm), check_names=False + ) + pd.testing.assert_series_equal( + result_df["fwhm_x_ppm"], pd.Series(expected_fwhm_x_ppm), check_names=False + ) + pd.testing.assert_series_equal( + result_df["fwhm_y_ppm"], pd.Series(expected_fwhm_y_ppm), check_names=False + ) + pd.testing.assert_series_equal( + result_df["fwhm_x_hz"], pd.Series(expected_fwhm_x_hz), check_names=False + ) + pd.testing.assert_series_equal( + result_df["fwhm_y_hz"], pd.Series(expected_fwhm_y_hz), check_names=False + ) + + +def test_calculate_peak_centers_in_ppm(): + # Sample data for testing + data = { + "center_x": [10, 20, 30], + "center_y": [15, 25, 35], + "init_center_x": [12, 22, 32], + "init_center_y": [18, 28, 38], + } + df = pd.DataFrame(data) + + # Mock peakipy_data object + peakipy_data = Mock() + peakipy_data.uc_f2.ppm = Mock(side_effect=lambda x: x * 0.1) + peakipy_data.uc_f1.ppm = Mock(side_effect=lambda x: x * 0.2) + + # Expected results + expected_center_x_ppm = [1.0, 2.0, 3.0] + expected_center_y_ppm = [3.0, 5.0, 7.0] + expected_init_center_x_ppm = [1.2, 2.2, 3.2] + expected_init_center_y_ppm = [3.6, 5.6, 7.6] + + # Run the function + result_df = calculate_peak_centers_in_ppm(df, peakipy_data) + + # Assertions + pd.testing.assert_series_equal( + result_df["center_x_ppm"], pd.Series(expected_center_x_ppm), check_names=False + ) + pd.testing.assert_series_equal( + result_df["center_y_ppm"], pd.Series(expected_center_y_ppm), check_names=False + ) + pd.testing.assert_series_equal( + result_df["init_center_x_ppm"], + pd.Series(expected_init_center_x_ppm), + check_names=False, + ) + pd.testing.assert_series_equal( + result_df["init_center_y_ppm"], + pd.Series(expected_init_center_y_ppm), + check_names=False, + ) diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..baf9f8da 
--- /dev/null
+++ b/test/test_utils.py
@@ -0,0 +1,252 @@
+from unittest.mock import patch, mock_open, MagicMock
+from datetime import datetime
+import json
+import os
+import tempfile
+from pathlib import Path
+
+import pytest
+import pandas as pd
+
+# run_log and the config helpers under test all live in peakipy.utils
+from peakipy.utils import (
+    run_log,
+    update_args_with_values_from_config_file,
+    update_peak_positions_from_ppm_to_points,
+    update_linewidths_from_hz_to_points,
+    save_data,
+)
+
+
+@patch("peakipy.utils.open", new_callable=mock_open)
+@patch("peakipy.utils.datetime")
+@patch("peakipy.utils.sys")
+def test_run_log(mock_sys, mock_datetime, mock_open_file):
+    # Mocking sys.argv
+    mock_sys.argv = ["test_script.py", "arg1", "arg2"]
+
+    # Mocking datetime to return a fixed timestamp
+    fixed_timestamp = datetime(2024, 5, 20, 15, 45)
+    mock_datetime.now.return_value = fixed_timestamp
+
+    # Expected timestamp string
+    expected_time_stamp = fixed_timestamp.strftime("%A %d %B %Y at %H:%M")
+
+    # Run the function
+    run_log("mock_run_log.txt")
+
+    # Prepare the expected log content
+    expected_log_content = (
+        f"# Script run on {expected_time_stamp}:\ntest_script.py arg1 arg2\n"
+    )
+
+    # Assert that the file was opened correctly
+    mock_open_file.assert_called_once_with("mock_run_log.txt", "a")
+
+    # Assert that the correct content was written to the file
+    mock_open_file().write.assert_called_once_with(expected_log_content)
+
+    # Assert that the script name is correctly set to the basename
+    assert mock_sys.argv[0] == "test_script.py"
+
+
+# patch load_config where it is looked up, i.e. in peakipy.utils
+@patch("peakipy.utils.load_config")
+@patch("peakipy.utils.Path.exists")
+def test_update_args_with_config(mock_path_exists, mock_load_config):
+    # Test setup
+    mock_path_exists.return_value = True  # Pretend the config file exists
+    mock_load_config.return_value = {
+        "dims": [1, 2, 3],
+        "noise": "0.05",
+        "colors": ["#ff0000", "#00ff00"],
+    }
+
+    args = {"dims": (0, 1, 2), "noise": False, "colors": ["#5e3c99", "#e66101"]}
+
+    # Run the function
+    updated_args, config = update_args_with_values_from_config_file(args)
+
+    # Check the updates to args
+    assert updated_args["dims"] == [1, 2, 3]
+    assert updated_args["noise"] == 0.05
+    assert updated_args["colors"] == ["#ff0000", "#00ff00"]
+
+    # Check the returned config
+    assert config == {
+        "dims": [1, 2, 3],
+        "noise": "0.05",
+        "colors": ["#ff0000", "#00ff00"],
+    }
+
+
+@patch("peakipy.utils.Path.exists")
+def test_update_args_with_no_config_file(mock_path_exists):
+    # Test setup
+    mock_path_exists.return_value = False  # Pretend the config file does not exist
+
+    args = {"dims": (0, 1, 2), "noise": False, "colors": ["#5e3c99", "#e66101"]}
+
+    # Run the function
+    updated_args, config = update_args_with_values_from_config_file(args)
+
+    # Check the updates to args
+    assert updated_args["dims"] == (0, 1, 2)
+    assert updated_args["noise"] == False
+    assert updated_args["colors"] == ["#5e3c99", "#e66101"]
+
+    # Check the returned config (should be empty)
+    assert config == {}
+
+
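+# a corrupt JSON config should behave like a missing one: the CLI args keep
+# their original values and an empty config is returned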
"noise": False, "colors": ["#5e3c99", "#e66101"]} + + # Run the function + updated_args, config = update_args_with_values_from_config_file(args) + + # Check the updates to args + assert updated_args["dims"] == (0, 1, 2) + assert updated_args["noise"] == False + assert updated_args["colors"] == ["#5e3c99", "#e66101"] + + # Check the returned config (should be empty due to error) + assert config == {} + + # Mock class to simulate the peakipy_data object + + +class MockPeakipyData: + def __init__(self, df, pt_per_hz_f2, pt_per_hz_f1, uc_f2, uc_f1): + self.df = df + self.pt_per_hz_f2 = pt_per_hz_f2 + self.pt_per_hz_f1 = pt_per_hz_f1 + self.uc_f2 = uc_f2 + self.uc_f1 = uc_f1 + + +# Test data +@pytest.fixture +def mock_peakipy_data(): + df = pd.DataFrame( + { + "XW_HZ": [10, 20, 30], + "YW_HZ": [5, 15, 25], + "X_PPM": [1.0, 2.0, 3.0], + "Y_PPM": [0.5, 1.5, 2.5], + } + ) + + pt_per_hz_f2 = 2.0 + pt_per_hz_f1 = 3.0 + + uc_f2 = MagicMock() + uc_f1 = MagicMock() + uc_f2.side_effect = lambda x, unit: x * 100.0 if unit == "PPM" else x + uc_f1.side_effect = lambda x, unit: x * 200.0 if unit == "PPM" else x + uc_f2.f = MagicMock(side_effect=lambda x, unit: x * 1000.0 if unit == "PPM" else x) + uc_f1.f = MagicMock(side_effect=lambda x, unit: x * 2000.0 if unit == "PPM" else x) + + return MockPeakipyData(df, pt_per_hz_f2, pt_per_hz_f1, uc_f2, uc_f1) + + +def test_update_linewidths_from_hz_to_points(mock_peakipy_data): + peakipy_data = update_linewidths_from_hz_to_points(mock_peakipy_data) + + expected_XW = [20.0, 40.0, 60.0] + expected_YW = [15.0, 45.0, 75.0] + + pd.testing.assert_series_equal( + peakipy_data.df["XW"], pd.Series(expected_XW, name="XW") + ) + pd.testing.assert_series_equal( + peakipy_data.df["YW"], pd.Series(expected_YW, name="YW") + ) + + +def test_update_peak_positions_from_ppm_to_points(mock_peakipy_data): + peakipy_data = update_peak_positions_from_ppm_to_points(mock_peakipy_data) + + expected_X_AXIS = [100.0, 200.0, 300.0] + expected_Y_AXIS = [100.0, 300.0, 500.0] + expected_X_AXISf = [1000.0, 2000.0, 3000.0] + expected_Y_AXISf = [1000.0, 3000.0, 5000.0] + + pd.testing.assert_series_equal( + peakipy_data.df["X_AXIS"], pd.Series(expected_X_AXIS, name="X_AXIS") + ) + pd.testing.assert_series_equal( + peakipy_data.df["Y_AXIS"], pd.Series(expected_Y_AXIS, name="Y_AXIS") + ) + pd.testing.assert_series_equal( + peakipy_data.df["X_AXISf"], pd.Series(expected_X_AXISf, name="X_AXISf") + ) + pd.testing.assert_series_equal( + peakipy_data.df["Y_AXISf"], pd.Series(expected_Y_AXISf, name="Y_AXISf") + ) + + +@pytest.fixture +def sample_dataframe(): + data = {"A": [1, 2, 3], "B": [4.5678, 5.6789, 6.7890]} + return pd.DataFrame(data) + + +def test_save_data_csv(sample_dataframe): + with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmpfile: + output_name = Path(tmpfile.name) + + try: + save_data(sample_dataframe, output_name) + + assert output_name.exists() + + # Load the CSV and compare with the original dataframe + loaded_df = pd.read_csv(output_name) + pd.testing.assert_frame_equal( + loaded_df, sample_dataframe, check_exact=False, rtol=1e-4 + ) + finally: + os.remove(output_name) + + +def test_save_data_tab(sample_dataframe): + with tempfile.NamedTemporaryFile(suffix=".tab", delete=False) as tmpfile: + output_name = Path(tmpfile.name) + + try: + save_data(sample_dataframe, output_name) + + assert output_name.exists() + + # Load the tab-separated file and compare with the original dataframe + loaded_df = pd.read_csv(output_name, sep="\t") + pd.testing.assert_frame_equal( + loaded_df, 
sample_dataframe, check_exact=False, rtol=1e-4 + ) + finally: + os.remove(output_name) + + +def test_save_data_pickle(sample_dataframe): + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmpfile: + output_name = Path(tmpfile.name) + + try: + save_data(sample_dataframe, output_name) + + assert output_name.exists() + + # Load the pickle file and compare with the original dataframe + loaded_df = pd.read_pickle(output_name) + pd.testing.assert_frame_equal(loaded_df, sample_dataframe) + finally: + os.remove(output_name)
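+
+
+# together, the three save_data tests above pin down its dispatch-on-suffix
+# behaviour: .csv writes comma-separated text, .tab writes tab-separated text
+# and .pkl writes a pickle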