diff --git a/CHANGELOG.md b/CHANGELOG.md index 5baad301..72aacd84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe) - Implement a converter for lattice files imported from Elegant (see #222, #251, #273, #281) (@hespe, @jank324) - `Beam` and `Element` objects now have a `.clone()` method to create a deep copy (see #289) (@hespe, @jank324) +- `ParticleBeam` now comes with methods for plotting the beam distribution in a variety of ways (see #292) (@roussel-ryan, @jank324) ### 🐛 Bug fixes diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index adf8a7ec..1d8d04be 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -1,13 +1,19 @@ -from typing import Optional, Union +import itertools +from typing import List, Literal, Optional, Tuple, Union +import numpy as np import torch +from matplotlib import pyplot as plt from scipy import constants from scipy.constants import physical_constants +from scipy.ndimage import gaussian_filter from torch.distributions import MultivariateNormal from cheetah.particles.beam import Beam from cheetah.utils import ( elementwise_linspace, + format_axis_as_percentage, + format_axis_with_prefixed_unit, unbiased_weighted_covariance, unbiased_weighted_std, verify_device_and_dtype, @@ -35,6 +41,15 @@ class ParticleBeam(Beam): :param dtype: Data type of the generated particles. """ + PRETTY_DIMENSION_LABELS = { + "x": r"$x$", + "px": r"$p_x$", + "y": r"$y$", + "py": r"$p_y$", + "tau": r"$\tau$", + "p": r"$\delta$", + } + def __init__( self, particles: torch.Tensor, @@ -883,6 +898,330 @@ def to_xyz_pxpypz(self) -> torch.Tensor: return xp_coords + def plot_1d_distribution( + self, + dimension: Literal["x", "px", "y", "py", "tau", "p"], + bins: int = 100, + bin_range: Optional[Tuple[float]] = None, + smoothing: float = 0.0, + plot_kws: Optional[dict] = None, + ax: Optional[plt.Axes] = None, + ) -> plt.Axes: + """ + Plot a 1D histogram of the given dimension of the particle distribution. + + :param dimension: Name of the dimension to plot. Should be one of + `('x', 'px', 'y', 'py', 'tau', 'p')`. + :param bins: Number of bins to use for the histogram. + :param bin_range: Range of the bins to use for the histogram. + :param smoothing: Standard deviation of the Gaussian kernel used to smooth the + histogram. + :param plot_kws: Additional keyword arguments to be passed to `plot` function of + matplotlib used to plot the histogram data. + :param ax: Matplotlib axes object to use for plotting. + :return: Matplotlib axes object with the plot. + """ + if ax is None: + _, ax = plt.subplots() + + x_array = getattr(self, dimension).cpu().detach().numpy() + histogram, edges = np.histogram(x_array, bins=bins, range=bin_range) + centers = (edges[:-1] + edges[1:]) / 2 + + if smoothing: + histogram = gaussian_filter(histogram, smoothing) + + ax.plot( + centers, + histogram / histogram.max(), + **{"color": "black"} | (plot_kws or {}), + ) + ax.set_xlabel(f"{self.PRETTY_DIMENSION_LABELS[dimension]}") + + # Handle units + if dimension in ("x", "y", "tau"): + base_unit = "m" + elif dimension in ("px", "py", "p"): + base_unit = "%" + + if dimension in ("x", "y", "tau"): + format_axis_with_prefixed_unit(ax.xaxis, base_unit, centers) + elif dimension in ("px", "py", "p"): + format_axis_as_percentage(ax.xaxis) + + return ax + + def plot_2d_distribution( + self, + x_dimension: Literal["x", "px", "y", "py", "tau", "p"], + y_dimension: Literal["x", "px", "y", "py", "tau", "p"], + contour: bool = False, + bins: int = 100, + bin_ranges: Optional[Tuple[Tuple[float]]] = None, + histogram_smoothing: float = 0.0, + contour_smoothing: float = 3.0, + pcolormesh_kws: Optional[dict] = None, + contour_kws: Optional[dict] = None, + ax: Optional[plt.Axes] = None, + ) -> plt.Axes: + """ + Plot a 2D histogram of the given dimensions of the particle distribution. + + :param x_dimension: Name of the x dimension to plot. Should be one of + `('x', 'px', 'y', 'py', 'tau', 'p')`. + :param y_dimension: Name of the y dimension to plot. Should be one of + `('x', 'px', 'y', 'py', 'tau', 'p')`. + :param contour: If `True`, overlay contour lines on the 2D histogram plot. + :param bins: Number of bins to use for the histogram in both dimensions. + :param bin_ranges: Ranges of the bins to use for the histogram in each + dimension. + :param smoothing: Standard deviation of the Gaussian kernel used to smooth the + histogram. + :param pcolormesh_kws: Additional keyword arguments to be passed to `pcolormesh` + function of matplotlib used to plot the histogram data. + :param contour_kws: Additional keyword arguments to be passed to `contour` + function of matplotlib used to plot the histogram data. + :param ax: Matplotlib axes object to use for plotting. + :return: Matplotlib axes object with the plot. + """ + if ax is None: + _, ax = plt.subplots() + + histogram, x_edges, y_edges = np.histogram2d( + getattr(self, x_dimension).cpu().detach().numpy(), + getattr(self, y_dimension).cpu().detach().numpy(), + bins=bins, + range=bin_ranges, + ) + x_centers = (x_edges[:-1] + x_edges[1:]) / 2 + y_centers = (y_edges[:-1] + y_edges[1:]) / 2 + + # Post-process and plot + smoothed_histogram = gaussian_filter(histogram, histogram_smoothing) + clipped_histogram = np.where(smoothed_histogram > 1, smoothed_histogram, np.nan) + ax.pcolormesh( + x_edges, + y_edges, + clipped_histogram.T / smoothed_histogram.max(), + **{"cmap": "rainbow"} | (pcolormesh_kws or {}), + ) + + if contour: + contour_histogram = gaussian_filter(histogram, contour_smoothing) + + ax.contour( + x_centers, + y_centers, + contour_histogram.T / contour_histogram.max(), + **{"levels": 3} | (contour_kws or {}), + ) + + ax.set_xlabel(f"{self.PRETTY_DIMENSION_LABELS[x_dimension]}") + ax.set_ylabel(f"{self.PRETTY_DIMENSION_LABELS[y_dimension]}") + + # Handle units + if x_dimension in ("x", "y", "tau"): + x_base_unit = "m" + elif x_dimension in ("px", "py", "p"): + x_base_unit = "%" + + if y_dimension in ("x", "y", "tau"): + y_base_unit = "m" + elif y_dimension in ("px", "py", "p"): + y_base_unit = "%" + + if x_dimension in ("x", "y", "tau"): + format_axis_with_prefixed_unit(ax.xaxis, x_base_unit, x_centers) + elif x_dimension in ("px", "py", "p"): + format_axis_as_percentage(ax.xaxis) + + if y_dimension in ("x", "y", "tau"): + format_axis_with_prefixed_unit(ax.yaxis, y_base_unit, y_centers) + elif y_dimension in ("px", "py", "p"): + format_axis_as_percentage(ax.yaxis) + + return ax + + def plot_distribution( + self, + dimensions: Tuple[str, ...] = ("x", "px", "y", "py", "tau", "p"), + bins: int = 100, + bin_ranges: Optional[ + Union[Literal["same"], Tuple[float], List[Tuple[float]]] + ] = None, + plot_1d_kws: Optional[dict] = None, + plot_2d_kws: Optional[dict] = None, + ) -> plt.Figure: + """ + Plot of coordinates projected into 2D planes. + + :param dimensions: Tuple of dimensions to plot. Should be a subset of + `('x', 'px', 'y', 'py', 'tau', 'p')`. + :param contour: If `True`, overlay contour lines on the 2D histogram plots. + :param bins: Number of bins to use for the histograms. + :param bin_ranges: Ranges of the bins to use for the histograms. If set to + `"unit_same"`, the same range is used for all dimensions that share the same + unit. If set to `None`, ranges are determined automatically. + :param smoothing: Standard deviation of the Gaussian kernel used to smooth the + histograms. + :param plot_1d_kws: Additional keyword arguments to be passed to + `ParticleBeam.plot_1d_distribution` for plotting 1D histograms. + :param plot_2d_kws: Additional keyword arguments to be passed to + `ParticleBeam.plot_2d_distribution` for plotting 2D histograms. + :return: Matplotlib figure object. + """ + fig, axs = plt.subplots( + len(dimensions), + len(dimensions), + figsize=(2 * len(dimensions), 2 * len(dimensions)), + ) + + # Determine bin ranges for all plots in the grid at once + full_tensor = ( + torch.stack([getattr(self, dimension) for dimension in dimensions], dim=-2) + .cpu() + .detach() + .numpy() + ) + if bin_ranges is None: + bin_ranges = [ + ( + full_tensor[i, :].min() + - (full_tensor[i, :].max() - full_tensor[i, :].min()) / 10, + full_tensor[i, :].max() + + (full_tensor[i, :].max() - full_tensor[i, :].min()) / 10, + ) + for i in range(full_tensor.shape[-2]) + ] + if bin_ranges == "unit_same": + spacial_idxs = [ + i + for i, dimension in enumerate(dimensions) + if dimension in ["x", "y", "tau"] + ] + spacial_bin_range = ( + full_tensor[spacial_idxs, :].min() + - ( + full_tensor[spacial_idxs, :].max() + - full_tensor[spacial_idxs, :].min() + ) + / 10, + full_tensor[spacial_idxs, :].max() + + ( + full_tensor[spacial_idxs, :].max() + - full_tensor[spacial_idxs, :].min() + ) + / 10, + ) + unitless_idxs = [ + i + for i, dimension in enumerate(dimensions) + if dimension in ["px", "py", "p"] + ] + unitless_bin_range = ( + full_tensor[unitless_idxs, :].min() + - ( + full_tensor[unitless_idxs, :].max() + - full_tensor[unitless_idxs, :].min() + ) + / 10, + full_tensor[unitless_idxs, :].max() + + ( + full_tensor[unitless_idxs, :].max() + - full_tensor[unitless_idxs, :].min() + ) + / 10, + ) + bin_range_dict = { + "x": spacial_bin_range, + "px": unitless_bin_range, + "y": spacial_bin_range, + "py": unitless_bin_range, + "tau": spacial_bin_range, + "p": unitless_bin_range, + } + bin_ranges = [bin_range_dict[dimension] for dimension in dimensions] + if np.asarray(bin_ranges).shape == (2,): + bin_ranges = [bin_ranges] * len(dimensions) + assert len(bin_ranges) == len(dimensions) and all( + len(e) == 2 for e in bin_ranges + ) + + # Plot diagonal 1D histograms on the diagonal + diagonal_axs = [axs[i, i] for i, _ in enumerate(dimensions)] + for dimension, bin_range, ax in zip(dimensions, bin_ranges, diagonal_axs): + self.plot_1d_distribution( + dimension=dimension, + bins=bins, + bin_range=bin_range, + ax=ax, + **(plot_1d_kws or {}), + ) + + # Plot 2D histograms on the off-diagonal + for i, j in itertools.combinations(range(len(dimensions)), 2): + self.plot_2d_distribution( + x_dimension=dimensions[i], + y_dimension=dimensions[j], + bins=bins, + bin_ranges=(bin_ranges[i], bin_ranges[j]), + ax=axs[j, i], + **(plot_2d_kws or {}), + ) + + # Hide unused axes + for i, j in itertools.combinations(range(len(dimensions)), 2): + axs[i, j].set_visible(False) + + # Clean up labels + for ax_column in axs.T: + for ax in ax_column[0:-1]: + ax.sharex(ax_column[0]) + ax.xaxis.set_tick_params(labelbottom=False) + ax.set_xlabel(None) + for i, ax_row in enumerate(axs): + for ax in ax_row[1:i]: + ax.sharey(ax_row[0]) + ax.yaxis.set_tick_params(labelleft=False) + ax.set_ylabel(None) + for i, _ in enumerate(dimensions): + axs[i, i].sharey(axs[0, 0]) + axs[i, i].set_yticks([]) + axs[i, i].set_ylabel(None) + + return fig + + def plot_point_cloud( + self, scatter_kws: Optional[dict] = None, ax: Optional[plt.Axes] = None + ) -> plt.Axes: + """ + Plot a 3D point cloud of the spatial coordinates of the particles. + + :param scatter_kws: Additional keyword arguments to be passed to the `scatter` + plotting function of matplotlib. + :param ax: Matplotlib axes object to use for plotting. + :return: Matplotlib axes object with the plot. + """ + if ax is None: + fig = plt.figure() + ax = fig.add_subplot(projection="3d") + + x = self.x.cpu().detach().numpy() + tau = self.tau.cpu().detach().numpy() + y = self.y.cpu().detach().numpy() + + ax.scatter(x, tau, y, c=self.p.cpu().detach().numpy(), **(scatter_kws or {})) + ax.set_xlabel(f"{self.PRETTY_DIMENSION_LABELS['x']}") + ax.set_ylabel(f"{self.PRETTY_DIMENSION_LABELS['tau']}") + ax.set_zlabel(f"{self.PRETTY_DIMENSION_LABELS['y']}") + + # Handle units + format_axis_with_prefixed_unit(ax.xaxis, "m", x) + format_axis_with_prefixed_unit(ax.yaxis, "m", tau) + format_axis_with_prefixed_unit(ax.zaxis, "m", y) + + return ax + def __len__(self) -> int: return int(self.num_particles) diff --git a/cheetah/utils/__init__.py b/cheetah/utils/__init__.py index ba74ef57..f015bd0c 100644 --- a/cheetah/utils/__init__.py +++ b/cheetah/utils/__init__.py @@ -4,6 +4,10 @@ from .elementwise_linspace import elementwise_linspace # noqa: F401 from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401 from .physics import compute_relativistic_factors # noqa: F401 +from .plot import ( # noqa: F401 + format_axis_as_percentage, + format_axis_with_prefixed_unit, +) from .statistics import ( # noqa: F401 unbiased_weighted_covariance, unbiased_weighted_std, diff --git a/cheetah/utils/plot.py b/cheetah/utils/plot.py new file mode 100644 index 00000000..1b33b979 --- /dev/null +++ b/cheetah/utils/plot.py @@ -0,0 +1,71 @@ +import matplotlib +import numpy as np + + +def format_axis_with_prefixed_unit( + axis: matplotlib.axis.Axis, base_unit: str, data: list[float] +) -> None: + """ + Adds an appropriately prefixed unit to the axis label and sets the tick formatter + accordingly to best match the given data. + """ + prefixed_unit, tick_formatter = determine_prefixed_unit_and_tick_formatter( + base_unit, data + ) + axis.set_label_text(f"{axis.get_label_text()} ({prefixed_unit})") + axis.set_major_formatter(tick_formatter) + axis.set_minor_formatter(tick_formatter) + + +def format_axis_as_percentage(axis: matplotlib.axis.Axis) -> None: + """ + Adds a percentage symbol to the axis label and sets the tick formatter accordingly. + """ + axis.set_label_text(f"{axis.get_label_text()} (%)") + axis.set_major_formatter(NoSymbolPercentFormatter()) + axis.set_minor_formatter(NoSymbolPercentFormatter()) + + +def determine_prefixed_unit_and_tick_formatter( + base_unit: str, data: list[float] +) -> tuple[str, matplotlib.ticker.FuncFormatter]: + """ + Considering the order of magnitude of some data points and their base unit, + determines the prefixed unit and the corresponding matplotlib tick formatter. + """ + if 1.0 <= np.max(np.abs(data)) < 1e3: + return base_unit, IdentityFormatter() + elif 1e-3 <= np.max(np.abs(data)) < 1.0: + return f"m{base_unit}", MilliFormatter() + elif 1e-6 <= np.max(np.abs(data)) < 1e-3: + return f"μ{base_unit}", MicroFormatter() + else: + return base_unit, IdentityFormatter() + + +class NoSymbolPercentFormatter(matplotlib.ticker.FuncFormatter): + """Formatter for percentages without the percent symbol.""" + + def __init__(self): + super().__init__(lambda x, _: f"{x * 100:.1f}") + + +class IdentityFormatter(matplotlib.ticker.FuncFormatter): + """Formatter for base values.""" + + def __init__(self): + super().__init__(lambda x, _: f"{x:.0f}") + + +class MilliFormatter(matplotlib.ticker.FuncFormatter): + """Formatter for milli values.""" + + def __init__(self): + super().__init__(lambda x, _: f"{x * 1e3:.0f}") + + +class MicroFormatter(matplotlib.ticker.FuncFormatter): + """Formatter for micro values.""" + + def __init__(self): + super().__init__(lambda x, _: f"{x * 1e6:.0f}") diff --git a/tests/test_plotting.py b/tests/test_plotting.py index f3b3f01b..dc3edea5 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -138,3 +138,19 @@ def test_plotting_with_gradients(): segment.plot_overview(incoming=beam) segment.plot_twiss(incoming=beam) + + +def test_plot_6d_particle_beam_distribution(): + """Test that the 6D `ParticleBeam` distribution plot does not raise an exception.""" + beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") + + # Run the plotting to see if it raises an exception + _ = beam.plot_distribution(bin_ranges="unit_same", plot_2d_kws={"contour": True}) + + +def test_plot_particle_beam_point_cloud(): + """Test that the `ParticleBeam`'s point cloud plot does not raise an exception.""" + beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") + + # Run the plotting to see if it raises an exception + _ = beam.plot_point_cloud()