Skip to content

Commit

Permalink
Merge pull request #292 from roussel-ryan/distribution-plotting
Browse files Browse the repository at this point in the history
Distribution plotting
  • Loading branch information
jank324 authored Dec 12, 2024
2 parents a784bb0 + 65527b4 commit b92e3c9
Show file tree
Hide file tree
Showing 5 changed files with 432 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
- `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe)
- Implement a converter for lattice files imported from Elegant (see #222, #251, #273, #281) (@hespe, @jank324)
- `Beam` and `Element` objects now have a `.clone()` method to create a deep copy (see #289) (@hespe, @jank324)
- `ParticleBeam` now comes with methods for plotting the beam distribution in a variety of ways (see #292) (@roussel-ryan, @jank324)

### 🐛 Bug fixes

Expand Down
341 changes: 340 additions & 1 deletion cheetah/particles/particle_beam.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from typing import Optional, Union
import itertools
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import torch
from matplotlib import pyplot as plt
from scipy import constants
from scipy.constants import physical_constants
from scipy.ndimage import gaussian_filter
from torch.distributions import MultivariateNormal

from cheetah.particles.beam import Beam
from cheetah.utils import (
elementwise_linspace,
format_axis_as_percentage,
format_axis_with_prefixed_unit,
unbiased_weighted_covariance,
unbiased_weighted_std,
verify_device_and_dtype,
Expand Down Expand Up @@ -35,6 +41,15 @@ class ParticleBeam(Beam):
:param dtype: Data type of the generated particles.
"""

PRETTY_DIMENSION_LABELS = {
"x": r"$x$",
"px": r"$p_x$",
"y": r"$y$",
"py": r"$p_y$",
"tau": r"$\tau$",
"p": r"$\delta$",
}

def __init__(
self,
particles: torch.Tensor,
Expand Down Expand Up @@ -883,6 +898,330 @@ def to_xyz_pxpypz(self) -> torch.Tensor:

return xp_coords

def plot_1d_distribution(
self,
dimension: Literal["x", "px", "y", "py", "tau", "p"],
bins: int = 100,
bin_range: Optional[Tuple[float]] = None,
smoothing: float = 0.0,
plot_kws: Optional[dict] = None,
ax: Optional[plt.Axes] = None,
) -> plt.Axes:
"""
Plot a 1D histogram of the given dimension of the particle distribution.
:param dimension: Name of the dimension to plot. Should be one of
`('x', 'px', 'y', 'py', 'tau', 'p')`.
:param bins: Number of bins to use for the histogram.
:param bin_range: Range of the bins to use for the histogram.
:param smoothing: Standard deviation of the Gaussian kernel used to smooth the
histogram.
:param plot_kws: Additional keyword arguments to be passed to `plot` function of
matplotlib used to plot the histogram data.
:param ax: Matplotlib axes object to use for plotting.
:return: Matplotlib axes object with the plot.
"""
if ax is None:
_, ax = plt.subplots()

x_array = getattr(self, dimension).cpu().detach().numpy()
histogram, edges = np.histogram(x_array, bins=bins, range=bin_range)
centers = (edges[:-1] + edges[1:]) / 2

if smoothing:
histogram = gaussian_filter(histogram, smoothing)

ax.plot(
centers,
histogram / histogram.max(),
**{"color": "black"} | (plot_kws or {}),
)
ax.set_xlabel(f"{self.PRETTY_DIMENSION_LABELS[dimension]}")

# Handle units
if dimension in ("x", "y", "tau"):
base_unit = "m"
elif dimension in ("px", "py", "p"):
base_unit = "%"

if dimension in ("x", "y", "tau"):
format_axis_with_prefixed_unit(ax.xaxis, base_unit, centers)
elif dimension in ("px", "py", "p"):
format_axis_as_percentage(ax.xaxis)

return ax

def plot_2d_distribution(
self,
x_dimension: Literal["x", "px", "y", "py", "tau", "p"],
y_dimension: Literal["x", "px", "y", "py", "tau", "p"],
contour: bool = False,
bins: int = 100,
bin_ranges: Optional[Tuple[Tuple[float]]] = None,
histogram_smoothing: float = 0.0,
contour_smoothing: float = 3.0,
pcolormesh_kws: Optional[dict] = None,
contour_kws: Optional[dict] = None,
ax: Optional[plt.Axes] = None,
) -> plt.Axes:
"""
Plot a 2D histogram of the given dimensions of the particle distribution.
:param x_dimension: Name of the x dimension to plot. Should be one of
`('x', 'px', 'y', 'py', 'tau', 'p')`.
:param y_dimension: Name of the y dimension to plot. Should be one of
`('x', 'px', 'y', 'py', 'tau', 'p')`.
:param contour: If `True`, overlay contour lines on the 2D histogram plot.
:param bins: Number of bins to use for the histogram in both dimensions.
:param bin_ranges: Ranges of the bins to use for the histogram in each
dimension.
:param smoothing: Standard deviation of the Gaussian kernel used to smooth the
histogram.
:param pcolormesh_kws: Additional keyword arguments to be passed to `pcolormesh`
function of matplotlib used to plot the histogram data.
:param contour_kws: Additional keyword arguments to be passed to `contour`
function of matplotlib used to plot the histogram data.
:param ax: Matplotlib axes object to use for plotting.
:return: Matplotlib axes object with the plot.
"""
if ax is None:
_, ax = plt.subplots()

histogram, x_edges, y_edges = np.histogram2d(
getattr(self, x_dimension).cpu().detach().numpy(),
getattr(self, y_dimension).cpu().detach().numpy(),
bins=bins,
range=bin_ranges,
)
x_centers = (x_edges[:-1] + x_edges[1:]) / 2
y_centers = (y_edges[:-1] + y_edges[1:]) / 2

# Post-process and plot
smoothed_histogram = gaussian_filter(histogram, histogram_smoothing)
clipped_histogram = np.where(smoothed_histogram > 1, smoothed_histogram, np.nan)
ax.pcolormesh(
x_edges,
y_edges,
clipped_histogram.T / smoothed_histogram.max(),
**{"cmap": "rainbow"} | (pcolormesh_kws or {}),
)

if contour:
contour_histogram = gaussian_filter(histogram, contour_smoothing)

ax.contour(
x_centers,
y_centers,
contour_histogram.T / contour_histogram.max(),
**{"levels": 3} | (contour_kws or {}),
)

ax.set_xlabel(f"{self.PRETTY_DIMENSION_LABELS[x_dimension]}")
ax.set_ylabel(f"{self.PRETTY_DIMENSION_LABELS[y_dimension]}")

# Handle units
if x_dimension in ("x", "y", "tau"):
x_base_unit = "m"
elif x_dimension in ("px", "py", "p"):
x_base_unit = "%"

if y_dimension in ("x", "y", "tau"):
y_base_unit = "m"
elif y_dimension in ("px", "py", "p"):
y_base_unit = "%"

if x_dimension in ("x", "y", "tau"):
format_axis_with_prefixed_unit(ax.xaxis, x_base_unit, x_centers)
elif x_dimension in ("px", "py", "p"):
format_axis_as_percentage(ax.xaxis)

if y_dimension in ("x", "y", "tau"):
format_axis_with_prefixed_unit(ax.yaxis, y_base_unit, y_centers)
elif y_dimension in ("px", "py", "p"):
format_axis_as_percentage(ax.yaxis)

return ax

def plot_distribution(
self,
dimensions: Tuple[str, ...] = ("x", "px", "y", "py", "tau", "p"),
bins: int = 100,
bin_ranges: Optional[
Union[Literal["same"], Tuple[float], List[Tuple[float]]]
] = None,
plot_1d_kws: Optional[dict] = None,
plot_2d_kws: Optional[dict] = None,
) -> plt.Figure:
"""
Plot of coordinates projected into 2D planes.
:param dimensions: Tuple of dimensions to plot. Should be a subset of
`('x', 'px', 'y', 'py', 'tau', 'p')`.
:param contour: If `True`, overlay contour lines on the 2D histogram plots.
:param bins: Number of bins to use for the histograms.
:param bin_ranges: Ranges of the bins to use for the histograms. If set to
`"unit_same"`, the same range is used for all dimensions that share the same
unit. If set to `None`, ranges are determined automatically.
:param smoothing: Standard deviation of the Gaussian kernel used to smooth the
histograms.
:param plot_1d_kws: Additional keyword arguments to be passed to
`ParticleBeam.plot_1d_distribution` for plotting 1D histograms.
:param plot_2d_kws: Additional keyword arguments to be passed to
`ParticleBeam.plot_2d_distribution` for plotting 2D histograms.
:return: Matplotlib figure object.
"""
fig, axs = plt.subplots(
len(dimensions),
len(dimensions),
figsize=(2 * len(dimensions), 2 * len(dimensions)),
)

# Determine bin ranges for all plots in the grid at once
full_tensor = (
torch.stack([getattr(self, dimension) for dimension in dimensions], dim=-2)
.cpu()
.detach()
.numpy()
)
if bin_ranges is None:
bin_ranges = [
(
full_tensor[i, :].min()
- (full_tensor[i, :].max() - full_tensor[i, :].min()) / 10,
full_tensor[i, :].max()
+ (full_tensor[i, :].max() - full_tensor[i, :].min()) / 10,
)
for i in range(full_tensor.shape[-2])
]
if bin_ranges == "unit_same":
spacial_idxs = [
i
for i, dimension in enumerate(dimensions)
if dimension in ["x", "y", "tau"]
]
spacial_bin_range = (
full_tensor[spacial_idxs, :].min()
- (
full_tensor[spacial_idxs, :].max()
- full_tensor[spacial_idxs, :].min()
)
/ 10,
full_tensor[spacial_idxs, :].max()
+ (
full_tensor[spacial_idxs, :].max()
- full_tensor[spacial_idxs, :].min()
)
/ 10,
)
unitless_idxs = [
i
for i, dimension in enumerate(dimensions)
if dimension in ["px", "py", "p"]
]
unitless_bin_range = (
full_tensor[unitless_idxs, :].min()
- (
full_tensor[unitless_idxs, :].max()
- full_tensor[unitless_idxs, :].min()
)
/ 10,
full_tensor[unitless_idxs, :].max()
+ (
full_tensor[unitless_idxs, :].max()
- full_tensor[unitless_idxs, :].min()
)
/ 10,
)
bin_range_dict = {
"x": spacial_bin_range,
"px": unitless_bin_range,
"y": spacial_bin_range,
"py": unitless_bin_range,
"tau": spacial_bin_range,
"p": unitless_bin_range,
}
bin_ranges = [bin_range_dict[dimension] for dimension in dimensions]
if np.asarray(bin_ranges).shape == (2,):
bin_ranges = [bin_ranges] * len(dimensions)
assert len(bin_ranges) == len(dimensions) and all(
len(e) == 2 for e in bin_ranges
)

# Plot diagonal 1D histograms on the diagonal
diagonal_axs = [axs[i, i] for i, _ in enumerate(dimensions)]
for dimension, bin_range, ax in zip(dimensions, bin_ranges, diagonal_axs):
self.plot_1d_distribution(
dimension=dimension,
bins=bins,
bin_range=bin_range,
ax=ax,
**(plot_1d_kws or {}),
)

# Plot 2D histograms on the off-diagonal
for i, j in itertools.combinations(range(len(dimensions)), 2):
self.plot_2d_distribution(
x_dimension=dimensions[i],
y_dimension=dimensions[j],
bins=bins,
bin_ranges=(bin_ranges[i], bin_ranges[j]),
ax=axs[j, i],
**(plot_2d_kws or {}),
)

# Hide unused axes
for i, j in itertools.combinations(range(len(dimensions)), 2):
axs[i, j].set_visible(False)

# Clean up labels
for ax_column in axs.T:
for ax in ax_column[0:-1]:
ax.sharex(ax_column[0])
ax.xaxis.set_tick_params(labelbottom=False)
ax.set_xlabel(None)
for i, ax_row in enumerate(axs):
for ax in ax_row[1:i]:
ax.sharey(ax_row[0])
ax.yaxis.set_tick_params(labelleft=False)
ax.set_ylabel(None)
for i, _ in enumerate(dimensions):
axs[i, i].sharey(axs[0, 0])
axs[i, i].set_yticks([])
axs[i, i].set_ylabel(None)

return fig

def plot_point_cloud(
self, scatter_kws: Optional[dict] = None, ax: Optional[plt.Axes] = None
) -> plt.Axes:
"""
Plot a 3D point cloud of the spatial coordinates of the particles.
:param scatter_kws: Additional keyword arguments to be passed to the `scatter`
plotting function of matplotlib.
:param ax: Matplotlib axes object to use for plotting.
:return: Matplotlib axes object with the plot.
"""
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(projection="3d")

x = self.x.cpu().detach().numpy()
tau = self.tau.cpu().detach().numpy()
y = self.y.cpu().detach().numpy()

ax.scatter(x, tau, y, c=self.p.cpu().detach().numpy(), **(scatter_kws or {}))
ax.set_xlabel(f"{self.PRETTY_DIMENSION_LABELS['x']}")
ax.set_ylabel(f"{self.PRETTY_DIMENSION_LABELS['tau']}")
ax.set_zlabel(f"{self.PRETTY_DIMENSION_LABELS['y']}")

# Handle units
format_axis_with_prefixed_unit(ax.xaxis, "m", x)
format_axis_with_prefixed_unit(ax.yaxis, "m", tau)
format_axis_with_prefixed_unit(ax.zaxis, "m", y)

return ax

def __len__(self) -> int:
return int(self.num_particles)

Expand Down
4 changes: 4 additions & 0 deletions cheetah/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from .elementwise_linspace import elementwise_linspace # noqa: F401
from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401
from .physics import compute_relativistic_factors # noqa: F401
from .plot import ( # noqa: F401
format_axis_as_percentage,
format_axis_with_prefixed_unit,
)
from .statistics import ( # noqa: F401
unbiased_weighted_covariance,
unbiased_weighted_std,
Expand Down
Loading

0 comments on commit b92e3c9

Please sign in to comment.