diff --git a/CHANGELOG.md b/CHANGELOG.md index 7efbeebc..dec4d198 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des ### 🚨 Breaking Changes - Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265, #284) (@jank324, @cr-xu, @hespe, @roussel-ryan) +- As part of the vectorised rewrite, the `Aperture` no longer removes particles. Instead, `ParticleBeam.survival_probabilities` tracks the probability that a particle has survived (i.e. the inverse probability that it has been lost). This also comes with the removal of `Beam.empty`. Note that particle losses in `Aperture` are currently not differentiable. This will be addressed in a future release. (see #268) (@cr-xu, @jank324) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe) - `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) - The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index 6999efdd..71f1b329 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -15,8 +15,11 @@ class Aperture(Element): """ Physical aperture. - :param x_max: half size horizontal offset in [m] - :param y_max: half size vertical offset in [m] + NOTE: The aperture currently only affects beams of type `ParticleBeam` and only has + an effect when the aperture is active. + + :param x_max: half size horizontal offset in [m]. + :param y_max: half size vertical offset in [m]. :param shape: Shape of the aperture. Can be "rectangular" or "elliptical". :param is_active: If the aperture actually blocks particles. :param name: Unique identifier of the element. @@ -72,7 +75,7 @@ def track(self, incoming: Beam) -> Beam: if not (isinstance(incoming, ParticleBeam) and self.is_active): return incoming - assert self.x_max >= 0 and self.y_max >= 0 + assert torch.all(self.x_max >= 0) and torch.all(self.y_max >= 0) assert self.shape in [ "rectangular", "elliptical", @@ -80,33 +83,28 @@ def track(self, incoming: Beam) -> Beam: if self.shape == "rectangular": survived_mask = torch.logical_and( - torch.logical_and(incoming.x > -self.x_max, incoming.x < self.x_max), - torch.logical_and(incoming.y > -self.y_max, incoming.y < self.y_max), + torch.logical_and( + incoming.x > -self.x_max.unsqueeze(-1), + incoming.x < self.x_max.unsqueeze(-1), + ), + torch.logical_and( + incoming.y > -self.y_max.unsqueeze(-1), + incoming.y < self.y_max.unsqueeze(-1), + ), ) elif self.shape == "elliptical": survived_mask = ( - incoming.x**2 / self.x_max**2 + incoming.y**2 / self.y_max**2 + incoming.x**2 / self.x_max.unsqueeze(-1) ** 2 + + incoming.y**2 / self.y_max.unsqueeze(-1) ** 2 ) <= 1.0 - outgoing_particles = incoming.particles[survived_mask] - - outgoing_particle_charges = incoming.particle_charges[survived_mask] - self.lost_particles = incoming.particles[torch.logical_not(survived_mask)] - - self.lost_particle_charges = incoming.particle_charges[ - torch.logical_not(survived_mask) - ] - - return ( - ParticleBeam( - outgoing_particles, - incoming.energy, - particle_charges=outgoing_particle_charges, - device=outgoing_particles.device, - dtype=outgoing_particles.dtype, - ) - if outgoing_particles.shape[0] > 0 - else ParticleBeam.empty + return ParticleBeam( + particles=incoming.particles, + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities * survived_mask, + device=incoming.particles.device, + dtype=incoming.particles.dtype, ) def split(self, resolution: torch.Tensor) -> list[Element]: diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index d0e636bd..405e8d24 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -37,9 +37,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: ) def track(self, incoming: Beam) -> Beam: - if incoming is Beam.empty: - self.reading = None - elif isinstance(incoming, ParameterBeam): + if isinstance(incoming, ParameterBeam): self.reading = torch.stack([incoming.mu_x, incoming.mu_y]) elif isinstance(incoming, ParticleBeam): self.reading = torch.stack([incoming.mu_x, incoming.mu_y]) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 46888e6c..8e4f9f5e 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -103,9 +103,7 @@ def track(self, incoming: Beam) -> Beam: :param incoming: Beam of particles entering the element. :return: Beam of particles exiting the element. """ - if incoming is Beam.empty: - return incoming - elif isinstance(incoming, (ParameterBeam, ParticleBeam)): + if isinstance(incoming, (ParameterBeam, ParticleBeam)): return self._track_beam(incoming) else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") @@ -244,9 +242,10 @@ def _track_beam(self, incoming: Beam) -> Beam: return outgoing else: # ParticleBeam outgoing = ParticleBeam( - outgoing_particles, - outgoing_energy, + particles=outgoing_particles, + energy=outgoing_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=outgoing_particles.device, dtype=outgoing_particles.dtype, ) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 6f253be6..393c65f6 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -260,6 +260,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index c76af4ef..ff63f371 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -114,6 +114,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index bfe1df7c..e050d106 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -60,9 +60,7 @@ def track(self, incoming: Beam) -> Beam: :param incoming: Beam of particles entering the element. :return: Beam of particles exiting the element. """ - if incoming is Beam.empty: - return incoming - elif isinstance(incoming, ParameterBeam): + if isinstance(incoming, ParameterBeam): tm = self.transfer_map(incoming.energy) mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1) cov = torch.matmul(tm, torch.matmul(incoming._cov, tm.transpose(-2, -1))) @@ -81,6 +79,7 @@ def track(self, incoming: Beam) -> Beam: new_particles, incoming.energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=new_particles.device, dtype=new_particles.dtype, ) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index cbc7e00c..99ee3c0a 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -187,9 +187,12 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + (x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index e1d6c0ea..93132dd1 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -158,6 +158,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1)) def track(self, incoming: Beam) -> Beam: + # Record the beam only when the screen is active if self.is_active: copy_of_incoming = deepcopy(incoming) @@ -185,7 +186,26 @@ def track(self, incoming: Beam) -> Beam: self.set_read_beam(copy_of_incoming) - return Beam.empty if self.is_blocking else incoming + # Block the beam only when the screen is active and blocking + if self.is_active and self.is_blocking: + if isinstance(incoming, ParameterBeam): + return ParameterBeam( + mu=incoming._mu, + cov=incoming._cov, + energy=incoming.energy, + total_charge=torch.zeros_like(incoming.total_charge), + ) + elif isinstance(incoming, ParticleBeam): + return ParticleBeam( + particles=incoming.particles, + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survival_probabilities=torch.zeros_like( + incoming.survival_probabilities + ), + ) + else: + return deepcopy(incoming) @property def reading(self) -> torch.Tensor: @@ -194,7 +214,7 @@ def reading(self) -> torch.Tensor: return self.cached_reading read_beam = self.get_read_beam() - if read_beam is Beam.empty or read_beam is None: + if read_beam is None: image = torch.zeros( (int(self.effective_resolution[1]), int(self.effective_resolution[0])), device=self.misalignment.device, @@ -255,16 +275,24 @@ def reading(self) -> torch.Tensor: ) image, _ = torch.histogramdd( - torch.stack((read_beam.x, read_beam.y)).T, bins=self.pixel_bin_edges + torch.stack((read_beam.x, read_beam.y)).T, + bins=self.pixel_bin_edges, + weight=read_beam.particle_charges + * read_beam.survival_probabilities, ) image = torch.flipud(image.T) elif self.method == "kde": + weights = read_beam.particle_charges * read_beam.survival_probabilities + broadcasted_x, broadcasted_y, broadcasted_weights = ( + torch.broadcast_tensors(read_beam.x, read_beam.y, weights) + ) image = kde_histogram_2d( - x1=read_beam.x, - x2=read_beam.y, + x1=broadcasted_x, + x2=broadcasted_y, bins1=self.pixel_bin_centers[0], bins2=self.pixel_bin_centers[1], bandwidth=self.kde_bandwidth, + weights=broadcasted_weights, ) # Change the x, y positions image = torch.transpose(image, -2, -1) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 4857ccf9..74808395 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -5,7 +5,7 @@ from scipy.constants import elementary_charge, epsilon_0, speed_of_light from cheetah.accelerator.element import Element -from cheetah.particles import Beam, ParticleBeam +from cheetah.particles import ParticleBeam from cheetah.utils import verify_device_and_dtype @@ -149,7 +149,8 @@ def _deposit_charge_on_grid( ) # Accumulate the charge contributions - repeated_charges = beam.particle_charges.repeat_interleave( + survived_particle_charges = beam.particle_charges * beam.survival_probabilities + repeated_charges = survived_particle_charges.repeat_interleave( repeats=8, dim=-1 ) # Shape:(..., 8 * num_particles) values = (cell_weights.flatten(start_dim=-2) * repeated_charges)[valid_mask] @@ -545,27 +546,34 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: :param incoming: Beam of particles entering the element. :returns: Beam of particles exiting the element. """ - if incoming is Beam.empty or incoming.particles.shape[0] == 0: - return incoming - elif isinstance(incoming, ParticleBeam): + if isinstance(incoming, ParticleBeam): # This flattening is a hack to only think about one vector dimension in the # following code. It is reversed at the end of the function. - # Make sure that the incoming beam has at least one vector dimension - if len(incoming.particles.shape) == 2: - is_incoming_vectorized = False - - vectorized_incoming = ParticleBeam( - particles=incoming.particles.unsqueeze(0), - energy=incoming.energy.unsqueeze(0), - particle_charges=incoming.particle_charges.unsqueeze(0), - device=incoming.particles.device, - dtype=incoming.particles.dtype, - ) - else: - is_incoming_vectorized = True - - vectorized_incoming = incoming + # Make sure that the incoming beam has at least one vector dimension by + # broadcasting with a dummy dimension (1,). + vector_shape = torch.broadcast_shapes( + incoming.particles.shape[:-2], + incoming.energy.shape, + incoming.particle_charges.shape[:-1], + incoming.survival_probabilities.shape[:-1], + (1,), + ) + vectorized_incoming = ParticleBeam( + particles=torch.broadcast_to( + incoming.particles, (*vector_shape, incoming.num_particles, 7) + ), + energy=torch.broadcast_to(incoming.energy, vector_shape), + particle_charges=torch.broadcast_to( + incoming.particle_charges, (*vector_shape, incoming.num_particles) + ), + survival_probabilities=torch.broadcast_to( + incoming.survival_probabilities, + (*vector_shape, incoming.num_particles), + ), + device=incoming.particles.device, + dtype=incoming.particles.dtype, + ) flattened_incoming = ParticleBeam( particles=vectorized_incoming.particles.flatten(end_dim=-3), @@ -573,6 +581,9 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=vectorized_incoming.particle_charges.flatten( end_dim=-2 ), + survival_probabilities=( + vectorized_incoming.survival_probabilities.flatten(end_dim=-2) + ), device=vectorized_incoming.particles.device, dtype=vectorized_incoming.particles.dtype, ) @@ -611,26 +622,25 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ..., 2 ] * dt.unsqueeze(-1) - if not is_incoming_vectorized: - # Reshape to the original non-vectorised shape - outgoing = ParticleBeam.from_xyz_pxpypz( - xp_coordinates.squeeze(0), - vectorized_incoming.energy.squeeze(0), - vectorized_incoming.particle_charges.squeeze(0), - vectorized_incoming.particles.device, - vectorized_incoming.particles.dtype, - ) - else: - # Reverse the flattening of the vector dimensions - outgoing = ParticleBeam.from_xyz_pxpypz( - xp_coordinates.unflatten( - dim=0, sizes=vectorized_incoming.particles.shape[:-2] - ), - vectorized_incoming.energy, - vectorized_incoming.particle_charges, - vectorized_incoming.particles.device, - vectorized_incoming.particles.dtype, - ) + # Reverse the flattening of the vector dimensions + outgoing_vector_shape = torch.broadcast_shapes( + incoming.particles.shape[:-2], + incoming.energy.shape, + incoming.particle_charges.shape[:-1], + incoming.survival_probabilities.shape[:-1], + self.effect_length.shape, + ) + outgoing = ParticleBeam.from_xyz_pxpypz( + xp_coordinates=xp_coordinates.reshape( + (*outgoing_vector_shape, incoming.num_particles, 7) + ), + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, + device=incoming.particles.device, + dtype=incoming.particles.dtype, + ) + return outgoing else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index aecab3fd..6f33417d 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -207,9 +207,12 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + (x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index 0e9b428b..5a331dc5 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -33,8 +33,6 @@ class directly, but use one of the subclasses. :math:`\Delta E = E - E_0` """ - empty = "I'm an empty beam!" - @classmethod @abstractmethod def from_parameters( diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 20e9223e..23b078a2 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -6,7 +6,12 @@ from torch.distributions import MultivariateNormal from cheetah.particles.beam import Beam -from cheetah.utils import elementwise_linspace, verify_device_and_dtype +from cheetah.utils import ( + elementwise_linspace, + unbiased_weighted_covariance, + unbiased_weighted_std, + verify_device_and_dtype, +) speed_of_light = torch.tensor(constants.speed_of_light) # In m/s electron_mass = torch.tensor(constants.electron_mass) # In kg @@ -21,7 +26,10 @@ class ParticleBeam(Beam): :param particles: List of 7-dimensional particle vectors. :param energy: Reference energy of the beam in eV. - :param total_charge: Total charge of the beam in C. + :param particle_charges: Charges of the macroparticles in the beam in C. + :param survival_probabilities: Vector of probabilities that each particle has + survived (i.e. not been lost), where 1.0 means the particle has survived and + 0.0 means the particle has been lost. Defaults to ones. :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. :param dtype: Data type of the generated particles. @@ -32,6 +40,7 @@ def __init__( particles: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, + survival_probabilities: Optional[torch.Tensor] = None, device=None, dtype=None, ) -> None: @@ -51,10 +60,18 @@ def __init__( ( particle_charges.to(**factory_kwargs) if particle_charges is not None - else torch.zeros(particles.shape[:2], **factory_kwargs) + else torch.zeros(particles.shape[-2], **factory_kwargs) ), ) self.register_buffer("energy", energy.to(**factory_kwargs)) + self.register_buffer( + "survival_probabilities", + ( + survival_probabilities.to(**factory_kwargs) + if survival_probabilities is not None + else torch.ones(particles.shape[-2], **factory_kwargs) + ), + ) @classmethod def from_parameters( @@ -795,6 +812,7 @@ def from_xyz_pxpypz( xp_coordinates: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, + survival_probabilities: Optional[torch.Tensor] = None, device=None, dtype=torch.float32, ) -> torch.Tensor: @@ -807,6 +825,7 @@ def from_xyz_pxpypz( particles=xp_coordinates.clone(), energy=energy, particle_charges=particle_charges, + survival_probabilities=survival_probabilities, device=device, dtype=dtype, ) @@ -872,15 +891,26 @@ def __len__(self) -> int: @property def total_charge(self) -> torch.Tensor: - return torch.sum(self.particle_charges, dim=-1) + """Total charge of the beam in C, taking into account particle losses.""" + return torch.sum(self.particle_charges * self.survival_probabilities, dim=-1) @property def num_particles(self) -> int: + """ + Length of the macroparticle array. + + NOTE: This does not account for lost particles. + """ return self.particles.shape[-2] + @property + def num_particles_survived(self) -> torch.Tensor: + """Number of macroparticles that have survived.""" + return self.survival_probabilities.sum(dim=-1) + @property def x(self) -> Optional[torch.Tensor]: - return self.particles[..., 0] if self is not Beam.empty else None + return self.particles[..., 0] @x.setter def x(self, value: torch.Tensor) -> None: @@ -888,15 +918,27 @@ def x(self, value: torch.Tensor) -> None: @property def mu_x(self) -> Optional[torch.Tensor]: - return self.x.mean(dim=-1) if self is not Beam.empty else None + """ + Mean of the :math:`x` coordinates of the particles, weighted by their + survival probability. + """ + return torch.sum( + (self.x * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_x(self) -> Optional[torch.Tensor]: - return self.x.std(dim=-1) if self is not Beam.empty else None + """ + Standard deviation of the :math:`x` coordinates of the particles, weighted + by their survival probability. + """ + return unbiased_weighted_std( + self.x, weights=self.survival_probabilities, dim=-1 + ) @property def px(self) -> Optional[torch.Tensor]: - return self.particles[..., 1] if self is not Beam.empty else None + return self.particles[..., 1] @px.setter def px(self, value: torch.Tensor) -> None: @@ -904,15 +946,27 @@ def px(self, value: torch.Tensor) -> None: @property def mu_px(self) -> Optional[torch.Tensor]: - return self.px.mean(dim=-1) if self is not Beam.empty else None + """ + Mean of the :math:`px` coordinates of the particles, weighted by their + survival probability. + """ + return torch.sum( + (self.px * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_px(self) -> Optional[torch.Tensor]: - return self.px.std(dim=-1) if self is not Beam.empty else None + """ + Standard deviation of the :math:`px` coordinates of the particles, weighted + by their survival probability. + """ + return unbiased_weighted_std( + self.px, weights=self.survival_probabilities, dim=-1 + ) @property def y(self) -> Optional[torch.Tensor]: - return self.particles[..., 2] if self is not Beam.empty else None + return self.particles[..., 2] @y.setter def y(self, value: torch.Tensor) -> None: @@ -920,15 +974,19 @@ def y(self, value: torch.Tensor) -> None: @property def mu_y(self) -> Optional[float]: - return self.y.mean(dim=-1) if self is not Beam.empty else None + return torch.sum( + (self.y * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_y(self) -> Optional[torch.Tensor]: - return self.y.std(dim=-1) if self is not Beam.empty else None + return unbiased_weighted_std( + self.y, weights=self.survival_probabilities, dim=-1 + ) @property def py(self) -> Optional[torch.Tensor]: - return self.particles[..., 3] if self is not Beam.empty else None + return self.particles[..., 3] @py.setter def py(self, value: torch.Tensor) -> None: @@ -936,15 +994,19 @@ def py(self, value: torch.Tensor) -> None: @property def mu_py(self) -> Optional[torch.Tensor]: - return self.py.mean(dim=-1) if self is not Beam.empty else None + return torch.sum( + (self.py * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_py(self) -> Optional[torch.Tensor]: - return self.py.std(dim=-1) if self is not Beam.empty else None + return unbiased_weighted_std( + self.py, weights=self.survival_probabilities, dim=-1 + ) @property def tau(self) -> Optional[torch.Tensor]: - return self.particles[..., 4] if self is not Beam.empty else None + return self.particles[..., 4] @tau.setter def tau(self, value: torch.Tensor) -> None: @@ -952,15 +1014,19 @@ def tau(self, value: torch.Tensor) -> None: @property def mu_tau(self) -> Optional[torch.Tensor]: - return self.tau.mean(dim=-1) if self is not Beam.empty else None + return torch.sum( + (self.tau * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_tau(self) -> Optional[torch.Tensor]: - return self.tau.std(dim=-1) if self is not Beam.empty else None + return unbiased_weighted_std( + self.tau, weights=self.survival_probabilities, dim=-1 + ) @property def p(self) -> Optional[torch.Tensor]: - return self.particles[..., 5] if self is not Beam.empty else None + return self.particles[..., 5] @p.setter def p(self, value: torch.Tensor) -> None: @@ -968,24 +1034,34 @@ def p(self, value: torch.Tensor) -> None: @property def mu_p(self) -> Optional[torch.Tensor]: - return self.p.mean(dim=-1) if self is not Beam.empty else None + return torch.sum( + (self.p * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_p(self) -> Optional[torch.Tensor]: - return self.p.std(dim=-1) if self is not Beam.empty else None + return unbiased_weighted_std( + self.p, weights=self.survival_probabilities, dim=-1 + ) @property def sigma_xpx(self) -> torch.Tensor: - return torch.mean( - (self.x - self.mu_x.unsqueeze(-1)) * (self.px - self.mu_px.unsqueeze(-1)), - dim=-1, + r""" + Returns the covariance between x and px. :math:`\sigma_{x, px}^2`. + It is weighted by the survival probability of the particles. + """ + return unbiased_weighted_covariance( + self.x, self.px, weights=self.survival_probabilities, dim=-1 ) @property def sigma_ypy(self) -> torch.Tensor: - return torch.mean( - (self.y - self.mu_y.unsqueeze(-1)) * (self.py - self.mu_py.unsqueeze(-1)), - dim=-1, + r""" + Returns the covariance between y and py. :math:`\sigma_{y, py}^2`. + It is weighted by the survival probability of the particles. + """ + return unbiased_weighted_covariance( + self.y, self.py, weights=self.survival_probabilities, dim=-1 ) @property diff --git a/cheetah/utils/__init__.py b/cheetah/utils/__init__.py index e65f40c6..ba74ef57 100644 --- a/cheetah/utils/__init__.py +++ b/cheetah/utils/__init__.py @@ -4,4 +4,9 @@ from .elementwise_linspace import elementwise_linspace # noqa: F401 from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401 from .physics import compute_relativistic_factors # noqa: F401 +from .statistics import ( # noqa: F401 + unbiased_weighted_covariance, + unbiased_weighted_std, + unbiased_weighted_variance, +) from .unique_name_generator import UniqueNameGenerator # noqa: F401 diff --git a/cheetah/utils/statistics.py b/cheetah/utils/statistics.py new file mode 100644 index 00000000..adfbeba2 --- /dev/null +++ b/cheetah/utils/statistics.py @@ -0,0 +1,62 @@ +import torch + + +def unbiased_weighted_covariance( + input1: torch.Tensor, input2: torch.Tensor, weights: torch.Tensor, dim: int = None +) -> torch.Tensor: + """ + Compute the unbiased weighted covariance of two tensors. + + :param input1: Input tensor 1. (..., sample_size) + :param input2: Input tensor 2. (..., sample_size) + :param weights: Weights tensor. (..., sample_size) + :param dim: Dimension along which to compute the covariance. + :return: Unbiased weighted covariance. (..., 2, 2) + """ + weighted_mean1 = torch.sum(input1 * weights, dim=dim) / torch.sum(weights, dim=dim) + weighted_mean2 = torch.sum(input2 * weights, dim=dim) / torch.sum(weights, dim=dim) + correction_factor = torch.sum(weights, dim=dim) - torch.sum( + weights**2, dim=dim + ) / torch.sum(weights, dim=dim) + covariance = torch.sum( + weights + * (input1 - weighted_mean1.unsqueeze(-1)) + * (input2 - weighted_mean2.unsqueeze(-1)), + dim=dim, + ) / (correction_factor) + return covariance + + +def unbiased_weighted_variance( + input: torch.Tensor, weights: torch.Tensor, dim: int = None +) -> torch.Tensor: + """ + Compute the unbiased weighted variance of a tensor. + + :param input: Input tensor. + :param weights: Weights tensor. + :param dim: Dimension along which to compute the variance. + :return: Unbiased weighted variance. + """ + weighted_mean = torch.sum(input * weights, dim=dim) / torch.sum(weights, dim=dim) + correction_factor = torch.sum(weights, dim=dim) - torch.sum( + weights**2, dim=dim + ) / torch.sum(weights, dim=dim) + variance = torch.sum( + weights * (input - weighted_mean.unsqueeze(-1)) ** 2, dim=dim + ) / (correction_factor) + return variance + + +def unbiased_weighted_std( + input: torch.Tensor, weights: torch.Tensor, dim: int = None +) -> torch.Tensor: + """ + Compute the unbiased weighted standard deviation of a tensor. + + :param input: Input tensor. + :param weights: Weights tensor. + :param dim: Dimension along which to compute the standard deviation. + :return: Unbiased weighted standard deviation. + """ + return torch.sqrt(unbiased_weighted_variance(input, weights, dim=dim)) diff --git a/tests/resources/bmadx/incoming.pt b/tests/resources/bmadx/incoming.pt index c901e8aa..5279a3d6 100644 Binary files a/tests/resources/bmadx/incoming.pt and b/tests/resources/bmadx/incoming.pt differ diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index 1478454e..bc1ab083 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -176,7 +176,10 @@ def test_aperture(): navigator.activate_apertures() _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) - assert outgoing_beam.num_particles == outgoing_p_array.rparticles.shape[1] + assert ( + int(outgoing_beam.num_particles_survived) + == outgoing_p_array.rparticles.shape[1] + ) def test_aperture_elliptical(): @@ -215,7 +218,13 @@ def test_aperture_elliptical(): navigator.activate_apertures() _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) - assert outgoing_beam.num_particles == outgoing_p_array.rparticles.shape[1] + assert ( + int(outgoing_beam.num_particles_survived) + == outgoing_p_array.rparticles.shape[1] + ) + + assert np.allclose(outgoing_beam.mu_x.cpu().numpy(), outgoing_p_array.x().mean()) + assert np.allclose(outgoing_beam.mu_px.cpu().numpy(), outgoing_p_array.px().mean()) def test_solenoid(): diff --git a/tests/test_dipole.py b/tests/test_dipole.py index 08d016e4..231e7ddf 100644 --- a/tests/test_dipole.py +++ b/tests/test_dipole.py @@ -121,19 +121,20 @@ def test_dipole_bmadx_tracking(dtype): dtype ) - angle = torch.tensor([20 * torch.pi / 180], dtype=dtype) + # TODO: See if Bmad-X test dtypes can be cleaned up now that dtype PR was merged + angle = torch.tensor(20 * torch.pi / 180, dtype=dtype) e1 = angle / 2 e2 = angle - e1 dipole_cheetah_bmadx = Dipole( - length=torch.tensor([0.5]), + length=torch.tensor(0.5), angle=angle, e1=e1, e2=e2, - tilt=torch.tensor([0.1], dtype=dtype), - fringe_integral=torch.tensor([0.5]), - fringe_integral_exit=torch.tensor([0.5]), - gap=torch.tensor([0.05], dtype=dtype), - gap_exit=torch.tensor([0.05], dtype=dtype), + tilt=torch.tensor(0.1, dtype=dtype), + fringe_integral=torch.tensor(0.5), + fringe_integral_exit=torch.tensor(0.5), + gap=torch.tensor(0.05, dtype=dtype), + gap_exit=torch.tensor(0.05, dtype=dtype), fringe_at="both", fringe_type="linear_edge", tracking_method="bmadx", diff --git a/tests/test_drift.py b/tests/test_drift.py index f0115216..bb5ba4b8 100644 --- a/tests/test_drift.py +++ b/tests/test_drift.py @@ -71,7 +71,7 @@ def test_drift_bmadx_tracking(dtype): "tests/resources/bmadx/incoming.pt", weights_only=False ).to(dtype) drift = cheetah.Drift( - length=torch.tensor([1.0]), tracking_method="bmadx", dtype=dtype + length=torch.tensor(1.0), tracking_method="bmadx", dtype=dtype ) # Run tracking diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 309dc07b..dad61da0 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -277,3 +277,46 @@ def test_space_charge_with_ares_astra_beam(): beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") _ = segment.track(beam) + + +def test_space_charge_with_aperture_cutoff(): + """ + Tests that the space charge kick is correctly applied only to surviving particles, + by comparing the results with and without an aperture that results in beam losses. + """ + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(0.2)), + cheetah.Aperture( + x_max=torch.tensor(1e-4), + y_max=torch.tensor(1e-4), + shape="rectangular", + is_active="False", + name="aperture", + ), + cheetah.Drift(length=torch.tensor(0.25)), + cheetah.SpaceChargeKick(effect_length=torch.tensor(0.5)), + cheetah.Drift(length=torch.tensor(0.25)), + ] + ) + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + total_charge=torch.tensor(1e-9), + mu_x=torch.tensor(5e-5), + sigma_px=torch.tensor(1e-4), + sigma_py=torch.tensor(1e-4), + ) + + # Track with inactive aperture + outgoing_beam_without_aperture = segment.track(incoming_beam) + + # Activate the aperture and track the beam + segment.aperture.is_active = True + outgoing_beam_with_aperture = segment.track(incoming_beam) + + # Check that with particle loss the space charge kick is different + assert not torch.allclose( + outgoing_beam_with_aperture.particles, outgoing_beam_without_aperture.particles + ) + # Check that the number of surviving particles is less than the initial number + assert outgoing_beam_with_aperture.survival_probabilities.sum(dim=-1).max() < 10_000 diff --git a/tests/test_statistics.py b/tests/test_statistics.py new file mode 100644 index 00000000..3b603851 --- /dev/null +++ b/tests/test_statistics.py @@ -0,0 +1,73 @@ +import torch + +from cheetah.utils import unbiased_weighted_covariance, unbiased_weighted_variance + + +def test_unbiased_weighted_variance_with_single_element(): + """Test that the variance is NaN when there is only one element.""" + data = torch.tensor([42.0]) + weights = torch.tensor([1.0]) + + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.isnan(computed_variance) + + +def test_unbiased_weighted_variance_with_same_weights(): + """ + Test that the weighted variance with all weights the same equals the unweighted + variance implementated in PyTorch. + """ + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) + + expected_variance = torch.var(data, unbiased=True) + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.allclose(computed_variance, expected_variance) + + +def test_unbiased_weighted_variance_with_different_weights(): + """Test that the variance is computed when some weights are different.""" + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.5, 0.5, 0.5, 0, 0]) + + expected_variance = torch.var(torch.tensor([1.0, 2.0, 3.0]), unbiased=True) + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.allclose(computed_variance, expected_variance) + + +def test_unbiased_weighted_variance_with_zero_weights(): + """Test that the variance is NaN when all weights are zero.""" + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]) + + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.isnan(computed_variance) + + +def test_unbiased_weighted_variance_with_small_numbers(): + """Test that the variance is correct for small numbers.""" + data = torch.tensor([1e-10, 2e-10, 3e-10, 4e-10, 5e-10], dtype=torch.float32) + weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32) + + expected_variance = torch.var(data, unbiased=True) + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.allclose(computed_variance, expected_variance) + + +def test_unbiased_weighted_covariance_reduced_to_variance(): + """ + Test that the covariance computation is correctly reduced to the variance when both + inputs are the same. + """ + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.5, 1.0, 1.0, 0.9, 0.9]) + + variance = unbiased_weighted_variance(data, weights) + covariance = unbiased_weighted_covariance(data, data, weights) + + assert torch.allclose(covariance, variance) diff --git a/tests/test_transverse_deflecting_cavity.py b/tests/test_transverse_deflecting_cavity.py index 3555532b..a9662aa8 100644 --- a/tests/test_transverse_deflecting_cavity.py +++ b/tests/test_transverse_deflecting_cavity.py @@ -14,10 +14,10 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype): "tests/resources/bmadx/incoming.pt", weights_only=False ).to(dtype) tdc = cheetah.TransverseDeflectingCavity( - length=torch.tensor([1.0]), - voltage=torch.tensor([1e7]), - phase=torch.tensor([0.2], dtype=dtype), - frequency=torch.tensor([1e9]), + length=torch.tensor(1.0), + voltage=torch.tensor(1e7), + phase=torch.tensor(0.2, dtype=dtype), + frequency=torch.tensor(1e9), tracking_method="bmadx", dtype=dtype, ) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index b9da0063..aa00598a 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -449,3 +450,47 @@ def test_broadcasting_solenoid_misalignment(): assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.particle_charges.shape == (100_000,) assert outgoing.energy.shape == (2,) + + +@pytest.mark.parametrize("aperture_shape", ["rectangular", "elliptical"]) +def test_vectorized_aperture_broadcasting(aperture_shape): + """ + Test that apertures work in a vectorised setting and that broadcasting rules are + applied correctly. + """ + torch.manual_seed(0) + + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=100_000, + sigma_py=torch.tensor(1e-4), + sigma_px=torch.tensor(2e-4), + energy=torch.tensor([154e6, 14e9]), + ) + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(0.5)), + cheetah.Aperture( + x_max=torch.tensor([[1e-5], [2e-4], [3e-4]]), + y_max=torch.tensor(2e-4), + shape=aperture_shape, + ), + cheetah.Drift(length=torch.tensor(0.5)), + ] + ) + + outgoing = segment.track(incoming) + + # Particle positions are unaffected by the aperture ... only their survival is + assert outgoing.particles.shape == (2, 100_000, 7) + assert outgoing.energy.shape == (2,) + assert outgoing.particle_charges.shape == (100_000,) + assert outgoing.survival_probabilities.shape == (3, 2, 100_000) + + if aperture_shape == "elliptical": + assert np.allclose( + outgoing.survival_probabilities.sum(dim=-1)[:, 0], [7672, 94523, 99547] + ) + elif aperture_shape == "rectangular": + assert np.allclose( + outgoing.survival_probabilities.sum(dim=-1)[:, 0], [7935, 95400, 99719] + )