diff --git a/CHANGELOG.md b/CHANGELOG.md index e7f7c395..dec4d198 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,10 @@ This is a major release with significant upgrades under the hood of Cheetah. Des ### 🚨 Breaking Changes - Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting in a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor for most of the properties of your element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265, #284) (@jank324, @cr-xu, @hespe, @roussel-ryan) +- As part of the vectorised rewrite, the `Aperture` no longer removes particles. Instead, `ParticleBeam.survival_probabilities` tracks the probability that a particle has survived (i.e. the complement of the probability that it has been lost). This also comes with the removal of `Beam.empty`. Note that particle losses in `Aperture` are currently not differentiable. This will be addressed in a future release. (see #268) (@cr-xu, @jank324) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x, p_x=\frac{P_x}{p_0}, y, p_y=\frac{P_y}{p_0}, \tau=c\Delta t, \delta=\frac{\Delta E}{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe) - `Screen` no longer blocks the beam (by default). To return to the old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) +- The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324) ### 🚀 Features diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index cdf5d256..71f1b329 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -1,13 +1,12 @@ -from typing import Literal, Optional, Union +from typing import Literal, Optional import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import nn from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam -from cheetah.utils import UniqueNameGenerator +from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -16,8 +15,11 @@ class Aperture(Element): """ Physical aperture. - :param x_max: half size horizontal offset in [m] - :param y_max: half size vertical offset in [m] + NOTE: The aperture currently only affects beams of type `ParticleBeam` and only has + an effect when the aperture is active. + + :param x_max: half size horizontal offset in [m]. + :param y_max: half size vertical offset in [m]. :param shape: Shape of the aperture. Can be "rectangular" or "elliptical". :param is_active: If the aperture actually blocks particles. :param name: Unique identifier of the element.
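A minimal usage sketch of the new aperture semantics described in the changelog above (parameter names as in `cheetah/accelerator/aperture.py` below; the beam parameters are illustrative):

```python
import torch

import cheetah

# An active aperture no longer shrinks the particle tensor ...
aperture = cheetah.Aperture(
    x_max=torch.tensor(2e-4), y_max=torch.tensor(2e-4), shape="elliptical"
)
beam = cheetah.ParticleBeam.from_parameters(
    num_particles=10_000, sigma_x=torch.tensor(1e-4), sigma_y=torch.tensor(1e-4)
)

outgoing = aperture.track(beam)

# ... instead, losses are tracked per particle, so the shape is preserved.
assert outgoing.particles.shape == beam.particles.shape
num_survived = outgoing.survival_probabilities.sum(dim=-1)
```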
@@ -25,14 +27,15 @@ class Aperture(Element): def __init__( self, - x_max: Optional[Union[torch.Tensor, nn.Parameter]] = None, - y_max: Optional[Union[torch.Tensor, nn.Parameter]] = None, + x_max: Optional[torch.Tensor] = None, + y_max: Optional[torch.Tensor] = None, shape: Literal["rectangular", "elliptical"] = "rectangular", is_active: bool = True, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype([x_max, y_max], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -72,7 +75,7 @@ def track(self, incoming: Beam) -> Beam: if not (isinstance(incoming, ParticleBeam) and self.is_active): return incoming - assert self.x_max >= 0 and self.y_max >= 0 + assert torch.all(self.x_max >= 0) and torch.all(self.y_max >= 0) assert self.shape in [ "rectangular", "elliptical", @@ -80,33 +83,28 @@ def track(self, incoming: Beam) -> Beam: if self.shape == "rectangular": survived_mask = torch.logical_and( - torch.logical_and(incoming.x > -self.x_max, incoming.x < self.x_max), - torch.logical_and(incoming.y > -self.y_max, incoming.y < self.y_max), + torch.logical_and( + incoming.x > -self.x_max.unsqueeze(-1), + incoming.x < self.x_max.unsqueeze(-1), + ), + torch.logical_and( + incoming.y > -self.y_max.unsqueeze(-1), + incoming.y < self.y_max.unsqueeze(-1), + ), ) elif self.shape == "elliptical": survived_mask = ( - incoming.x**2 / self.x_max**2 + incoming.y**2 / self.y_max**2 + incoming.x**2 / self.x_max.unsqueeze(-1) ** 2 + + incoming.y**2 / self.y_max.unsqueeze(-1) ** 2 ) <= 1.0 - outgoing_particles = incoming.particles[survived_mask] - - outgoing_particle_charges = incoming.particle_charges[survived_mask] - self.lost_particles = incoming.particles[torch.logical_not(survived_mask)] - - self.lost_particle_charges = incoming.particle_charges[ - torch.logical_not(survived_mask) - ] - - return ( - ParticleBeam( - outgoing_particles, - incoming.energy, - particle_charges=outgoing_particle_charges, - device=outgoing_particles.device, - dtype=outgoing_particles.dtype, - ) - if outgoing_particles.shape[0] > 0 - else ParticleBeam.empty + return ParticleBeam( + particles=incoming.particles, + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities * survived_mask, + device=incoming.particles.device, + dtype=incoming.particles.dtype, ) def split(self, resolution: torch.Tensor) -> list[Element]: diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index d0e636bd..405e8d24 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -37,9 +37,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: ) def track(self, incoming: Beam) -> Beam: - if incoming is Beam.empty: - self.reading = None - elif isinstance(incoming, ParameterBeam): + if isinstance(incoming, ParameterBeam): self.reading = torch.stack([incoming.mu_x, incoming.mu_y]) elif isinstance(incoming, ParticleBeam): self.reading = torch.stack([incoming.mu_x, incoming.mu_y]) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index c7a89e05..8e4f9f5e 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -1,16 +1,19 @@ -from typing import Optional, Union +from typing import Optional import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle from scipy import constants from scipy.constants import physical_constants -from torch import nn from 
cheetah.accelerator.element import Element from cheetah.particles import Beam, ParameterBeam, ParticleBeam from cheetah.track_methods import base_rmatrix -from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + compute_relativistic_factors, + verify_device_and_dtype, +) generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -30,14 +33,17 @@ class Cavity(Element): def __init__( self, - length: Union[torch.Tensor, nn.Parameter], - voltage: Optional[Union[torch.Tensor, nn.Parameter]] = None, - phase: Optional[Union[torch.Tensor, nn.Parameter]] = None, - frequency: Optional[Union[torch.Tensor, nn.Parameter]] = None, + length: torch.Tensor, + voltage: Optional[torch.Tensor] = None, + phase: Optional[torch.Tensor] = None, + frequency: Optional[torch.Tensor] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [length, voltage, phase, frequency], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -97,9 +103,7 @@ def track(self, incoming: Beam) -> Beam: :param incoming: Beam of particles entering the element. :return: Beam of particles exiting the element. """ - if incoming is Beam.empty: - return incoming - elif isinstance(incoming, (ParameterBeam, ParticleBeam)): + if isinstance(incoming, (ParameterBeam, ParticleBeam)): return self._track_beam(incoming) else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") @@ -238,9 +242,10 @@ def _track_beam(self, incoming: Beam) -> Beam: return outgoing else: # ParticleBeam outgoing = ParticleBeam( - outgoing_particles, - outgoing_energy, + particles=outgoing_particles, + energy=outgoing_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=outgoing_particles.device, dtype=outgoing_particles.dtype, ) @@ -248,13 +253,12 @@ def _track_beam(self, incoming: Beam) -> Beam: def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: """Produces an R-matrix for a cavity when it is on, i.e. voltage > 0.0.""" - device = self.length.device - dtype = self.length.dtype + factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype} phi = torch.deg2rad(self.phase) delta_energy = self.voltage * torch.cos(phi) # Comment from Ocelot: Pure pi-standing-wave case - eta = torch.tensor(1.0, device=device, dtype=dtype) + eta = torch.tensor(1.0, **factory_kwargs) Ei = energy / electron_mass_eV Ef = (energy + delta_energy) / electron_mass_eV Ep = (Ef - Ei) / self.length # Derivative of the energy @@ -288,12 +292,12 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: ) ) - r56 = torch.tensor(0.0) - beta0 = torch.tensor(1.0) - beta1 = torch.tensor(1.0) + r56 = torch.tensor(0.0, **factory_kwargs) + beta0 = torch.tensor(1.0, **factory_kwargs) + beta1 = torch.tensor(1.0, **factory_kwargs) - k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light) - r55_cor = torch.tensor(0.0) + k = 2 * torch.pi * self.frequency / constants.speed_of_light + r55_cor = torch.tensor(0.0, **factory_kwargs) if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if? 
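The pattern `device, dtype = verify_device_and_dtype([...], device, dtype)` that this hunk introduces into the constructors resolves the factory arguments from whichever tensor arguments were actually passed, matching the PyTorch convention noted in the changelog. A minimal sketch of the behaviour implied by these call sites — the real helper in `cheetah.utils` may differ in details such as how conflicting devices or dtypes are handled:

```python
from typing import Optional

import torch


def verify_device_and_dtype(
    tensors: list[Optional[torch.Tensor]],
    device: Optional[torch.device],
    dtype: Optional[torch.dtype],
) -> tuple[torch.device, torch.dtype]:
    """Infer device and dtype from the first passed tensor unless overridden."""
    given = [tensor for tensor in tensors if tensor is not None]
    if device is None:
        # Assumption: fall back to the CPU when no tensor argument was passed.
        device = given[0].device if given else torch.device("cpu")
    if dtype is None:
        dtype = given[0].dtype if given else torch.get_default_dtype()
    return device, dtype
```

With this in place, `Cavity(length=torch.tensor(1.0, dtype=torch.float64))` ends up with `float64` buffers throughout, while an explicit `dtype=` argument still overrides the inference.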
beta0 = torch.sqrt(1 - 1 / Ei**2) beta1 = torch.sqrt(1 - 1 / Ef**2) @@ -320,7 +324,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: r11, r12, r21, r22, r55_cor, r56, r65, r66 ) - R = torch.eye(7, device=device, dtype=dtype).repeat((*r11.shape, 1, 1)) + R = torch.eye(7, **factory_kwargs).repeat((*r11.shape, 1, 1)) R[..., 0, 0] = r11 R[..., 0, 1] = r12 R[..., 1, 0] = r21 diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index 5baf87d2..bc9b85da 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -1,13 +1,12 @@ -from typing import Optional, Union +from typing import Optional import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import nn from cheetah.accelerator.element import Element from cheetah.particles import Beam -from cheetah.utils import UniqueNameGenerator +from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -19,12 +18,13 @@ class CustomTransferMap(Element): def __init__( self, - transfer_map: Union[torch.Tensor, nn.Parameter], + transfer_map: torch.Tensor, length: Optional[torch.Tensor] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype([transfer_map, length], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 5e919fbf..393c65f6 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -1,16 +1,15 @@ -from typing import Literal, Optional, Union +from typing import Literal, Optional import matplotlib.pyplot as plt import numpy as np import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import nn from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam from cheetah.track_methods import base_rmatrix, rotation_matrix -from cheetah.utils import UniqueNameGenerator, bmadx +from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -47,23 +46,39 @@ class Dipole(Element): def __init__( self, - length: Union[torch.Tensor, nn.Parameter], - angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, - k1: Optional[Union[torch.Tensor, nn.Parameter]] = None, - e1: Optional[Union[torch.Tensor, nn.Parameter]] = None, - e2: Optional[Union[torch.Tensor, nn.Parameter]] = None, - tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None, - gap: Optional[Union[torch.Tensor, nn.Parameter]] = None, - gap_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None, - fringe_integral: Optional[Union[torch.Tensor, nn.Parameter]] = None, - fringe_integral_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None, + length: torch.Tensor, + angle: Optional[torch.Tensor] = None, + k1: Optional[torch.Tensor] = None, + e1: Optional[torch.Tensor] = None, + e2: Optional[torch.Tensor] = None, + tilt: Optional[torch.Tensor] = None, + gap: Optional[torch.Tensor] = None, + gap_exit: Optional[torch.Tensor] = None, + fringe_integral: Optional[torch.Tensor] = None, + fringe_integral_exit: Optional[torch.Tensor] = None, fringe_at: Literal["neither", "entrance", "exit", "both"] = "both", fringe_type: Literal["linear_edge"] = "linear_edge", 
tracking_method: Literal["cheetah", "bmadx"] = "cheetah", name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ): + device, dtype = verify_device_and_dtype( + [ + length, + angle, + k1, + e1, + e2, + tilt, + gap, + gap_exit, + fringe_integral, + fringe_integral_exit, + ], + device, + dtype, + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -203,7 +218,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: # Begin Bmad-X tracking x, px, y, py = bmadx.offset_particle_set( - torch.tensor(0.0), torch.tensor(0.0), self.tilt, x, px, y, py + torch.zeros_like(self.tilt), + torch.zeros_like(self.tilt), + self.tilt, + x, + px, + y, + py, ) if self.fringe_at == "entrance" or self.fringe_at == "both": @@ -215,7 +236,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: px, py = self._bmadx_fringe_linear("exit", x, px, y, py) x, px, y, py = bmadx.offset_particle_unset( - torch.tensor(0.0), torch.tensor(0.0), self.tilt, x, px, y, py + torch.zeros_like(self.tilt), + torch.zeros_like(self.tilt), + self.tilt, + x, + px, + y, + py, ) # End of Bmad-X tracking @@ -233,6 +260,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) @@ -240,15 +268,15 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: def _bmadx_body( self, - x: Union[torch.Tensor, nn.Parameter], - px: Union[torch.Tensor, nn.Parameter], - y: Union[torch.Tensor, nn.Parameter], - py: Union[torch.Tensor, nn.Parameter], - z: Union[torch.Tensor, nn.Parameter], - pz: Union[torch.Tensor, nn.Parameter], - p0c: Union[torch.Tensor, nn.Parameter], + x: torch.Tensor, + px: torch.Tensor, + y: torch.Tensor, + py: torch.Tensor, + z: torch.Tensor, + pz: torch.Tensor, + p0c: torch.Tensor, mc2: float, - ) -> list[Union[torch.Tensor, nn.Parameter]]: + ) -> list[torch.Tensor]: """ Track particle coordinates through bend body. @@ -335,11 +363,11 @@ def _bmadx_body( def _bmadx_fringe_linear( self, location: Literal["entrance", "exit"], - x: Union[torch.Tensor, nn.Parameter], - px: Union[torch.Tensor, nn.Parameter], - y: Union[torch.Tensor, nn.Parameter], - py: Union[torch.Tensor, nn.Parameter], - ) -> list[Union[torch.Tensor, nn.Parameter]]: + x: torch.Tensor, + px: torch.Tensor, + y: torch.Tensor, + py: torch.Tensor, + ) -> list[torch.Tensor]: """ Tracks linear fringe. 
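Note the pattern in the `offset_particle_set`/`offset_particle_unset` calls above: replacing `torch.tensor(0.0)` with `torch.zeros_like(self.tilt)` makes the zero offsets inherit the batch shape, dtype and device of `tilt` instead of being a lone CPU `float32` scalar. A standalone illustration (the batch shape is hypothetical):

```python
import torch

tilt = torch.full((3,), 0.1, dtype=torch.float64)  # e.g. three vectorised simulations

old_offset = torch.tensor(0.0)       # scalar float32 on the CPU, regardless of tilt
new_offset = torch.zeros_like(tilt)  # matches tilt in shape, dtype and device

assert new_offset.shape == tilt.shape and new_offset.dtype == tilt.dtype
```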
diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index eb5fb187..ff63f371 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -1,9 +1,8 @@ -from typing import Literal, Optional, Union +from typing import Literal, Optional import matplotlib.pyplot as plt import torch from scipy.constants import physical_constants -from torch import nn from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam @@ -28,11 +27,11 @@ class Drift(Element): def __init__( self, - length: Union[torch.Tensor, nn.Parameter], + length: torch.Tensor, tracking_method: Literal["cheetah", "bmadx"] = "cheetah", name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -115,6 +114,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index bfe1df7c..e050d106 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -60,9 +60,7 @@ def track(self, incoming: Beam) -> Beam: :param incoming: Beam of particles entering the element. :return: Beam of particles exiting the element. """ - if incoming is Beam.empty: - return incoming - elif isinstance(incoming, ParameterBeam): + if isinstance(incoming, ParameterBeam): tm = self.transfer_map(incoming.energy) mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1) cov = torch.matmul(tm, torch.matmul(incoming._cov, tm.transpose(-2, -1))) @@ -81,6 +79,7 @@ def track(self, incoming: Beam) -> Beam: new_particles, incoming.energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=new_particles.device, dtype=new_particles.dtype, ) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index a00c2cbe..36456aa0 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -1,13 +1,16 @@ -from typing import Optional, Union +from typing import Optional import matplotlib.pyplot as plt import numpy as np import torch from matplotlib.patches import Rectangle -from torch import nn from cheetah.accelerator.element import Element -from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + compute_relativistic_factors, + verify_device_and_dtype, +) generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -25,12 +28,13 @@ class HorizontalCorrector(Element): def __init__( self, - length: Union[torch.Tensor, nn.Parameter], - angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + length: torch.Tensor, + angle: Optional[torch.Tensor] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype([length, angle], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index db6a559d..99ee3c0a 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -1,16 +1,15 @@ -from typing import Literal, Optional, Union +from typing import 
Literal, Optional import matplotlib.pyplot as plt import numpy as np import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import nn from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam from cheetah.track_methods import base_rmatrix, misalignment_matrix -from cheetah.utils import UniqueNameGenerator, bmadx +from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -34,16 +33,19 @@ class Quadrupole(Element): def __init__( self, - length: Union[torch.Tensor, nn.Parameter], - k1: Optional[Union[torch.Tensor, nn.Parameter]] = None, - misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, - tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None, + length: torch.Tensor, + k1: Optional[torch.Tensor] = None, + misalignment: Optional[torch.Tensor] = None, + tilt: Optional[torch.Tensor] = None, num_steps: int = 1, tracking_method: Literal["cheetah", "bmadx"] = "cheetah", name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [length, k1, misalignment, tilt], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -185,9 +187,12 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + (x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/rbend.py b/cheetah/accelerator/rbend.py index b50ef08e..110f9e07 100644 --- a/cheetah/accelerator/rbend.py +++ b/cheetah/accelerator/rbend.py @@ -1,7 +1,6 @@ -from typing import Optional, Union +from typing import Optional import torch -from torch import nn from cheetah.accelerator.dipole import Dipole from cheetah.utils import UniqueNameGenerator @@ -28,18 +27,18 @@ class RBend(Dipole): def __init__( self, - length: Optional[Union[torch.Tensor, nn.Parameter]], - angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, - k1: Optional[Union[torch.Tensor, nn.Parameter]] = None, - e1: Optional[Union[torch.Tensor, nn.Parameter]] = None, - e2: Optional[Union[torch.Tensor, nn.Parameter]] = None, - tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None, - fringe_integral: Optional[Union[torch.Tensor, nn.Parameter]] = None, - fringe_integral_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None, - gap: Optional[Union[torch.Tensor, nn.Parameter]] = None, + length: Optional[torch.Tensor], + angle: Optional[torch.Tensor] = None, + k1: Optional[torch.Tensor] = None, + e1: Optional[torch.Tensor] = None, + e2: Optional[torch.Tensor] = None, + tilt: Optional[torch.Tensor] = None, + fringe_integral: Optional[torch.Tensor] = None, + fringe_integral_exit: Optional[torch.Tensor] = None, + gap: Optional[torch.Tensor] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ): super().__init__( length=length, diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index d7b153d6..93132dd1 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -1,15 +1,14 @@ from copy import deepcopy -from typing import 
Literal, Optional, Union +from typing import Literal, Optional import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle -from torch import nn from torch.distributions import MultivariateNormal from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.utils import UniqueNameGenerator, kde_histogram_2d +from cheetah.utils import UniqueNameGenerator, kde_histogram_2d, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -43,29 +42,35 @@ class Screen(Element): def __init__( self, - resolution: Optional[Union[torch.Tensor, nn.Parameter]] = None, - pixel_size: Optional[Union[torch.Tensor, nn.Parameter]] = None, - binning: Optional[Union[torch.Tensor, nn.Parameter]] = None, - misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, + resolution: tuple[int, int] = (1024, 1024), + pixel_size: Optional[torch.Tensor] = None, + binning: int = 1, + misalignment: Optional[torch.Tensor] = None, method: Literal["histogram", "kde"] = "histogram", - kde_bandwidth: Optional[Union[torch.Tensor, nn.Parameter]] = None, + kde_bandwidth: Optional[torch.Tensor] = None, is_blocking: bool = False, is_active: bool = False, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [pixel_size, misalignment, kde_bandwidth], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) - self.register_buffer( - "resolution", - ( - torch.as_tensor(resolution, **factory_kwargs) - if resolution is not None - else torch.tensor((1024, 1024), **factory_kwargs) - ), - ) + assert method in [ + "histogram", + "kde", + ], f"Invalid method {method}. Must be either 'histogram' or 'kde'." + + self.resolution = resolution + self.binning = binning + self.method = method + self.is_blocking = is_blocking + self.is_active = is_active + self.register_buffer( "pixel_size", ( @@ -74,14 +79,6 @@ def __init__( else torch.tensor((1e-3, 1e-3), **factory_kwargs) ), ) - self.register_buffer( - "binning", - ( - torch.as_tensor(binning, **factory_kwargs) - if binning is not None - else torch.tensor(1, **factory_kwargs) - ), - ) self.register_buffer( "misalignment", ( @@ -94,11 +91,6 @@ def __init__( "length", torch.zeros(self.misalignment.shape[:-1], **factory_kwargs), ) - assert method in [ - "histogram", - "kde", - ], f"Invalid method {method}. Must be either 'histogram' or 'kde'." 
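Together with the `track` rewrite further below in this file, these `Screen` changes turn `resolution` and `binning` into plain Python values, so `effective_resolution` becomes an integer division. A usage sketch with the pixel values used by the converters later in this diff:

```python
import torch

import cheetah

screen = cheetah.Screen(
    resolution=(2448, 2040),  # now a plain tuple of ints rather than a tensor
    pixel_size=torch.tensor((3.5488e-6, 2.5003e-6)),
    binning=2,                # now a plain int
    is_active=True,
    is_blocking=True,         # when active, blocks by zeroing charge/survival
)
assert screen.effective_resolution == (1224, 1020)
```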
- self.method = method self.register_buffer( "kde_bandwidth", ( @@ -107,8 +99,6 @@ def __init__( else torch.clone(self.pixel_size[0]) ), ) - self.is_blocking = is_blocking - self.is_active = is_active self.set_read_beam(None) self.cached_reading = None @@ -118,8 +108,11 @@ def is_skippable(self) -> bool: return not self.is_active @property - def effective_resolution(self) -> torch.Tensor: - return self.resolution / self.binning + def effective_resolution(self) -> tuple[int, int]: + return ( + self.resolution[0] // self.binning, + self.resolution[1] // self.binning, + ) @property def effective_pixel_size(self) -> torch.Tensor: @@ -165,6 +158,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1)) def track(self, incoming: Beam) -> Beam: + # Record the beam only when the screen is active if self.is_active: copy_of_incoming = deepcopy(incoming) @@ -192,7 +186,26 @@ def track(self, incoming: Beam) -> Beam: self.set_read_beam(copy_of_incoming) - return Beam.empty if self.is_blocking else incoming + # Block the beam only when the screen is active and blocking + if self.is_active and self.is_blocking: + if isinstance(incoming, ParameterBeam): + return ParameterBeam( + mu=incoming._mu, + cov=incoming._cov, + energy=incoming.energy, + total_charge=torch.zeros_like(incoming.total_charge), + ) + elif isinstance(incoming, ParticleBeam): + return ParticleBeam( + particles=incoming.particles, + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survival_probabilities=torch.zeros_like( + incoming.survival_probabilities + ), + ) + else: + return deepcopy(incoming) @property def reading(self) -> torch.Tensor: @@ -201,9 +214,11 @@ def reading(self) -> torch.Tensor: return self.cached_reading read_beam = self.get_read_beam() - if read_beam is Beam.empty or read_beam is None: + if read_beam is None: image = torch.zeros( - (int(self.effective_resolution[1]), int(self.effective_resolution[0])) + (int(self.effective_resolution[1]), int(self.effective_resolution[0])), + device=self.misalignment.device, + dtype=self.misalignment.dtype, ) elif isinstance(read_beam, ParameterBeam): if torch.numel(read_beam._mu[..., 0]) > 1: @@ -260,16 +275,24 @@ def reading(self) -> torch.Tensor: ) image, _ = torch.histogramdd( - torch.stack((read_beam.x, read_beam.y)).T, bins=self.pixel_bin_edges + torch.stack((read_beam.x, read_beam.y)).T, + bins=self.pixel_bin_edges, + weight=read_beam.particle_charges + * read_beam.survival_probabilities, ) image = torch.flipud(image.T) elif self.method == "kde": + weights = read_beam.particle_charges * read_beam.survival_probabilities + broadcasted_x, broadcasted_y, broadcasted_weights = ( + torch.broadcast_tensors(read_beam.x, read_beam.y, weights) + ) image = kde_histogram_2d( - x1=read_beam.x, - x2=read_beam.y, + x1=broadcasted_x, + x2=broadcasted_y, bins1=self.pixel_bin_centers[0], bins2=self.pixel_bin_centers[1], bandwidth=self.kde_bandwidth, + weights=broadcasted_weights, ) # Change the x, y positions image = torch.transpose(image, -2, -1) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index f8faf24c..d59ccd16 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -1,14 +1,17 @@ -from typing import Optional, Union +from typing import Optional import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import nn from 
cheetah.accelerator.element import Element from cheetah.track_methods import misalignment_matrix -from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + compute_relativistic_factors, + verify_device_and_dtype, +) generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -31,13 +34,16 @@ class Solenoid(Element): def __init__( self, - length: Union[torch.Tensor, nn.Parameter] = None, - k: Optional[Union[torch.Tensor, nn.Parameter]] = None, - misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, + length: torch.Tensor = None, + k: Optional[torch.Tensor] = None, + misalignment: Optional[torch.Tensor] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [length, k, misalignment], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 7bb7d7f3..74808395 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -1,12 +1,12 @@ -from typing import Optional, Union +from typing import Optional import matplotlib.pyplot as plt import torch from scipy.constants import elementary_charge, epsilon_0, speed_of_light -from torch import nn from cheetah.accelerator.element import Element -from cheetah.particles import Beam, ParticleBeam +from cheetah.particles import ParticleBeam +from cheetah.utils import verify_device_and_dtype class SpaceChargeKick(Element): @@ -47,29 +47,27 @@ class SpaceChargeKick(Element): def __init__( self, - effect_length: Union[torch.Tensor, nn.Parameter], - num_grid_points_x: Union[torch.Tensor, nn.Parameter, int] = 32, - num_grid_points_y: Union[torch.Tensor, nn.Parameter, int] = 32, - num_grid_points_tau: Union[torch.Tensor, nn.Parameter, int] = 32, - grid_extend_x: Union[torch.Tensor, nn.Parameter] = 3, - grid_extend_y: Union[torch.Tensor, nn.Parameter] = 3, - grid_extend_tau: Union[torch.Tensor, nn.Parameter] = 3, + effect_length: torch.Tensor, + num_grid_points_x: int = 32, # TODO: Simplify these to a single tuple? + num_grid_points_y: int = 32, + num_grid_points_tau: int = 32, + grid_extend_x: torch.Tensor = 3, # TODO: Simplify these to a single tensor? 
+ grid_extend_y: torch.Tensor = 3, + grid_extend_tau: torch.Tensor = 3, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype([effect_length], device, dtype) self.factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) + self.grid_shape = (num_grid_points_x, num_grid_points_y, num_grid_points_tau) + self.register_buffer( "effect_length", torch.as_tensor(effect_length, **self.factory_kwargs) ) - self.grid_shape = ( - int(num_grid_points_x), - int(num_grid_points_y), - int(num_grid_points_tau), - ) # In multiples of sigma self.register_buffer( "grid_extend_x", torch.as_tensor(grid_extend_x, **self.factory_kwargs) @@ -151,7 +149,8 @@ def _deposit_charge_on_grid( ) # Accumulate the charge contributions - repeated_charges = beam.particle_charges.repeat_interleave( + survived_particle_charges = beam.particle_charges * beam.survival_probabilities + repeated_charges = survived_particle_charges.repeat_interleave( repeats=8, dim=-1 ) # Shape:(..., 8 * num_particles) values = (cell_weights.flatten(start_dim=-2) * repeated_charges)[valid_mask] @@ -547,27 +546,34 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: :param incoming: Beam of particles entering the element. :returns: Beam of particles exiting the element. """ - if incoming is Beam.empty or incoming.particles.shape[0] == 0: - return incoming - elif isinstance(incoming, ParticleBeam): + if isinstance(incoming, ParticleBeam): # This flattening is a hack to only think about one vector dimension in the # following code. It is reversed at the end of the function. - # Make sure that the incoming beam has at least one vector dimension - if len(incoming.particles.shape) == 2: - is_incoming_vectorized = False - - vectorized_incoming = ParticleBeam( - particles=incoming.particles.unsqueeze(0), - energy=incoming.energy.unsqueeze(0), - particle_charges=incoming.particle_charges.unsqueeze(0), - device=incoming.particles.device, - dtype=incoming.particles.dtype, - ) - else: - is_incoming_vectorized = True - - vectorized_incoming = incoming + # Make sure that the incoming beam has at least one vector dimension by + # broadcasting with a dummy dimension (1,). 
+ vector_shape = torch.broadcast_shapes( + incoming.particles.shape[:-2], + incoming.energy.shape, + incoming.particle_charges.shape[:-1], + incoming.survival_probabilities.shape[:-1], + (1,), + ) + vectorized_incoming = ParticleBeam( + particles=torch.broadcast_to( + incoming.particles, (*vector_shape, incoming.num_particles, 7) + ), + energy=torch.broadcast_to(incoming.energy, vector_shape), + particle_charges=torch.broadcast_to( + incoming.particle_charges, (*vector_shape, incoming.num_particles) + ), + survival_probabilities=torch.broadcast_to( + incoming.survival_probabilities, + (*vector_shape, incoming.num_particles), + ), + device=incoming.particles.device, + dtype=incoming.particles.dtype, + ) flattened_incoming = ParticleBeam( particles=vectorized_incoming.particles.flatten(end_dim=-3), @@ -575,6 +581,9 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=vectorized_incoming.particle_charges.flatten( end_dim=-2 ), + survival_probabilities=( + vectorized_incoming.survival_probabilities.flatten(end_dim=-2) + ), device=vectorized_incoming.particles.device, dtype=vectorized_incoming.particles.dtype, ) @@ -589,7 +598,11 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ], dim=-1, ) - cell_size = 2 * grid_dimensions / torch.tensor(self.grid_shape) + cell_size = ( + 2 + * grid_dimensions + / torch.tensor(self.grid_shape, **self.factory_kwargs) + ) dt = flattened_length_effect / ( speed_of_light * flattened_incoming.relativistic_beta ) @@ -609,26 +622,25 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ..., 2 ] * dt.unsqueeze(-1) - if not is_incoming_vectorized: - # Reshape to the original non-vectorised shape - outgoing = ParticleBeam.from_xyz_pxpypz( - xp_coordinates.squeeze(0), - vectorized_incoming.energy.squeeze(0), - vectorized_incoming.particle_charges.squeeze(0), - vectorized_incoming.particles.device, - vectorized_incoming.particles.dtype, - ) - else: - # Reverse the flattening of the vector dimensions - outgoing = ParticleBeam.from_xyz_pxpypz( - xp_coordinates.unflatten( - dim=0, sizes=vectorized_incoming.particles.shape[:-2] - ), - vectorized_incoming.energy, - vectorized_incoming.particle_charges, - vectorized_incoming.particles.device, - vectorized_incoming.particles.dtype, - ) + # Reverse the flattening of the vector dimensions + outgoing_vector_shape = torch.broadcast_shapes( + incoming.particles.shape[:-2], + incoming.energy.shape, + incoming.particle_charges.shape[:-1], + incoming.survival_probabilities.shape[:-1], + self.effect_length.shape, + ) + outgoing = ParticleBeam.from_xyz_pxpypz( + xp_coordinates=xp_coordinates.reshape( + (*outgoing_vector_shape, incoming.num_particles, 7) + ), + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, + device=incoming.particles.device, + dtype=incoming.particles.dtype, + ) + return outgoing else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index e4e9a3cb..6f33417d 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -1,14 +1,13 @@ -from typing import Literal, Optional, Union +from typing import Literal, Optional import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants, speed_of_light -from torch import nn from 
cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam -from cheetah.utils import UniqueNameGenerator, bmadx +from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -34,18 +33,21 @@ class TransverseDeflectingCavity(Element): def __init__( self, - length: Union[torch.Tensor, nn.Parameter], - voltage: Optional[Union[torch.Tensor, nn.Parameter]] = None, - phase: Optional[Union[torch.Tensor, nn.Parameter]] = None, - frequency: Optional[Union[torch.Tensor, nn.Parameter]] = None, - misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, - tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None, + length: torch.Tensor, + voltage: Optional[torch.Tensor] = None, + phase: Optional[torch.Tensor] = None, + frequency: Optional[torch.Tensor] = None, + misalignment: Optional[torch.Tensor] = None, + tilt: Optional[torch.Tensor] = None, num_steps: int = 1, tracking_method: Literal["bmadx"] = "bmadx", name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [length, voltage, phase, frequency, misalignment, tilt], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -87,7 +89,7 @@ def __init__( ( torch.as_tensor(tilt, **factory_kwargs) if tilt is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.num_steps = num_steps @@ -205,9 +207,12 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + (x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index e7304870..290ab9a5 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -1,10 +1,9 @@ -from typing import Optional, Union +from typing import Optional import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import nn from cheetah.accelerator.element import Element from cheetah.utils import UniqueNameGenerator @@ -28,11 +27,11 @@ class Undulator(Element): def __init__( self, - length: Union[torch.Tensor, nn.Parameter], + length: torch.Tensor, is_active: bool = False, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index be5ba4e4..78e78fb3 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -1,14 +1,17 @@ -from typing import Optional, Union +from typing import Optional import matplotlib.pyplot as plt import numpy as np import torch from matplotlib.patches import Rectangle from scipy.constants import physical_constants -from torch import nn from cheetah.accelerator.element import Element -from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + 
compute_relativistic_factors, + verify_device_and_dtype, +) generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -28,12 +31,13 @@ class VerticalCorrector(Element): def __init__( self, - length: Union[torch.Tensor, nn.Parameter], - angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + length: torch.Tensor, + angle: Optional[torch.Tensor] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype([length, angle], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/converters/nxtables.py b/cheetah/converters/nxtables.py index 03af4e93..6531e170 100644 --- a/cheetah/converters/nxtables.py +++ b/cheetah/converters/nxtables.py @@ -62,51 +62,51 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]: elif class_name == "BSCX": element = cheetah.Screen( name=name, - resolution=torch.tensor((2464, 2056)), + resolution=(2464, 2056), pixel_size=torch.tensor((0.00343e-3, 0.00247e-3)), - binning=torch.tensor(1), + binning=1, ) elif class_name == "BSCR": element = cheetah.Screen( name=name, - resolution=torch.tensor([2448, 2040]), + resolution=(2448, 2040), pixel_size=torch.tensor([3.5488e-6, 2.5003e-6]), - binning=torch.tensor(1), + binning=1, ) elif class_name == "BSCM": element = cheetah.Screen( # TODO: Ask for actual parameters name=name, - resolution=torch.tensor([2448, 2040]), + resolution=(2448, 2040), pixel_size=torch.tensor([3.5488e-6, 2.5003e-6]), - binning=torch.tensor(1), + binning=1, ) elif class_name == "BSCO": element = cheetah.Screen( # TODO: Ask for actual parameters name=name, - resolution=torch.tensor([2448, 2040]), + resolution=(2448, 2040), pixel_size=torch.tensor([3.5488e-6, 2.5003e-6]), - binning=torch.tensor(1), + binning=1, ) elif class_name == "BSCA": element = cheetah.Screen( # TODO: Ask for actual parameters name=name, - resolution=torch.tensor([2448, 2040]), + resolution=(2448, 2040), pixel_size=torch.tensor([3.5488e-6, 2.5003e-6]), - binning=torch.tensor(1), + binning=1, ) elif class_name == "BSCE": element = cheetah.Screen( # TODO: Ask for actual parameters name=name, - resolution=torch.tensor((2464, 2056)), + resolution=(2464, 2056), pixel_size=torch.tensor((0.00998e-3, 0.00715e-3)), - binning=torch.tensor(1), + binning=1, ) elif class_name == "SCRD": element = cheetah.Screen( # TODO: Ask for actual parameters name=name, - resolution=torch.tensor((2464, 2056)), + resolution=(2464, 2056), pixel_size=torch.tensor((0.00998e-3, 0.00715e-3)), - binning=torch.tensor(1), + binning=1, ) elif class_name == "BPMG": element = cheetah.BPM(name=name) diff --git a/cheetah/converters/ocelot.py b/cheetah/converters/ocelot.py index 27ea9221..51fc444f 100644 --- a/cheetah/converters/ocelot.py +++ b/cheetah/converters/ocelot.py @@ -140,7 +140,7 @@ def convert_element_to_cheetah( " properties." ) return cheetah.Screen( - resolution=torch.tensor([2448, 2040]), + resolution=(2448, 2040), pixel_size=torch.tensor([3.5488e-6, 2.5003e-6]), name=element.id, device=device, diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index a0ce6c91..5a331dc5 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -33,8 +33,6 @@ class directly, but use one of the subclasses. :math:`\Delta E = E - E_0` """ - empty = "I'm an empty beam!" 
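With the `Beam.empty` sentinel removed here, the `incoming is Beam.empty` branches deleted from the element `track` methods above have no direct replacement: a fully blocked `ParticleBeam` is now an ordinary beam whose survival probabilities are all zero. A hedged migration sketch for downstream code (`is_effectively_empty` is a hypothetical helper, not part of Cheetah):

```python
import torch

from cheetah.particles import ParticleBeam


def is_effectively_empty(beam: ParticleBeam) -> bool:
    # Hypothetical replacement for the old `incoming is Beam.empty` check:
    # the beam is "empty" once every particle has been lost.
    return bool(torch.all(beam.survival_probabilities == 0.0))
```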
- @classmethod @abstractmethod def from_parameters( @@ -55,7 +53,7 @@ energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "Beam": """ Create beam with given beam parameters. @@ -101,7 +99,7 @@ def from_twiss( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "Beam": """ Create a beam from twiss parameters. @@ -127,7 +125,7 @@ @classmethod @abstractmethod - def from_ocelot(cls, parray) -> "Beam": + def from_ocelot(cls, parray, device=None, dtype=None) -> "Beam": """ Convert an Ocelot ParticleArray `parray` to a Cheetah Beam. """ @@ -135,7 +133,7 @@ @classmethod @abstractmethod - def from_astra(cls, path: str, **kwargs) -> "Beam": + def from_astra(cls, path: str, device=None, dtype=None) -> "Beam": """Load an Astra particle distribution as a Cheetah Beam.""" raise NotImplementedError @@ -154,7 +152,7 @@ def transformed_to( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "Beam": """ Create version of this beam that is transformed to new beam parameters. @@ -179,6 +177,9 @@ CUDA GPU is selected if available. The CPU is used otherwise. :param dtype: Data type of the transformed beam. """ + device = device if device is not None else self.mu_x.device + dtype = dtype if dtype is not None else self.mu_x.dtype + # Figure out vector dimensions of the original beam and check that passed # arguments have the same vector dimensions. shape = self.mu_x.shape diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index ab509a11..e4bec7ef 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -5,6 +5,7 @@ from cheetah.particles.beam import Beam from cheetah.particles.particle_beam import ParticleBeam +from cheetah.utils import verify_device_and_dtype class ParameterBeam(Beam): @@ -26,8 +27,11 @@ def __init__( energy: torch.Tensor, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [mu, cov, energy, total_charge], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -62,24 +66,66 @@ def from_parameters( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParameterBeam": + # Extract device and dtype from given arguments + device, dtype = verify_device_and_dtype( + [ + mu_x, + mu_px, + mu_y, + mu_py, + sigma_x, + sigma_px, + sigma_y, + sigma_py, + sigma_tau, + sigma_p, + cor_x, + cor_y, + cor_tau, + energy, + total_charge, + ], + device, + dtype, + ) + factory_kwargs = {"device": device, "dtype": dtype} + # Set default values without function call in function signature - mu_x = mu_x if mu_x is not None else torch.tensor(0.0) - mu_px = mu_px if mu_px is not None else torch.tensor(0.0) - mu_y = mu_y if mu_y is not None else torch.tensor(0.0) - mu_py = mu_py if mu_py is not None else torch.tensor(0.0) - sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) - sigma_py = sigma_py if sigma_py is not None else
torch.tensor(2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) - cor_x = cor_x if cor_x is not None else torch.tensor(0.0) - cor_y = cor_y if cor_y is not None else torch.tensor(0.0) - cor_tau = cor_tau if cor_tau is not None else torch.tensor(0.0) - energy = energy if energy is not None else torch.tensor(1e8) - total_charge = total_charge if total_charge is not None else torch.tensor(0.0) + mu_x = mu_x if mu_x is not None else torch.tensor(0.0, **factory_kwargs) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0, **factory_kwargs) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0, **factory_kwargs) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0, **factory_kwargs) + sigma_x = ( + sigma_x if sigma_x is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_px = ( + sigma_px if sigma_px is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_y = ( + sigma_y if sigma_y is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_py = ( + sigma_py if sigma_py is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_tau = ( + sigma_tau if sigma_tau is not None else torch.tensor(1e-6, **factory_kwargs) + ) + sigma_p = ( + sigma_p if sigma_p is not None else torch.tensor(1e-6, **factory_kwargs) + ) + cor_x = cor_x if cor_x is not None else torch.tensor(0.0, **factory_kwargs) + cor_y = cor_y if cor_y is not None else torch.tensor(0.0, **factory_kwargs) + cor_tau = ( + cor_tau if cor_tau is not None else torch.tensor(0.0, **factory_kwargs) + ) + energy = energy if energy is not None else torch.tensor(1e8, **factory_kwargs) + total_charge = ( + total_charge + if total_charge is not None + else torch.tensor(0.0, **factory_kwargs) + ) mu_x, mu_px, mu_y, mu_py = torch.broadcast_tensors(mu_x, mu_px, mu_y, mu_py) mu = torch.stack( @@ -116,7 +162,7 @@ def from_parameters( cor_tau, sigma_p, ) - cov = torch.zeros(*sigma_x.shape, 7, 7) + cov = torch.zeros(*sigma_x.shape, 7, 7, **factory_kwargs) cov[..., 0, 0] = sigma_x**2 cov[..., 0, 1] = cor_x cov[..., 1, 0] = cor_x @@ -154,12 +200,11 @@ def from_twiss( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParameterBeam": - # Figure out if arguments were passed, figure out their shape - not_nones = [ - argument - for argument in [ + # Extract device and dtype from given arguments + device, dtype = verify_device_and_dtype( + [ beta_x, alpha_x, emittance_x, @@ -171,32 +216,45 @@ def from_twiss( cor_tau, energy, total_charge, - ] - if argument is not None - ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) - if len(not_nones) > 1: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." 
+ ], + device, + dtype, + ) + factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature - beta_x = beta_x if beta_x is not None else torch.full(shape, 1.0) - alpha_x = alpha_x if alpha_x is not None else torch.full(shape, 0.0) + beta_x = beta_x if beta_x is not None else torch.tensor(1.0, **factory_kwargs) + alpha_x = ( + alpha_x if alpha_x is not None else torch.tensor(0.0, **factory_kwargs) + ) emittance_x = ( - emittance_x if emittance_x is not None else torch.full(shape, 7.1971891e-13) + emittance_x + if emittance_x is not None + else torch.tensor(7.1971891e-13, **factory_kwargs) + ) + beta_y = beta_y if beta_y is not None else torch.tensor(1.0, **factory_kwargs) + alpha_y = ( + alpha_y if alpha_y is not None else torch.tensor(0.0, **factory_kwargs) ) - beta_y = beta_y if beta_y is not None else torch.full(shape, 1.0) - alpha_y = alpha_y if alpha_y is not None else torch.full(shape, 0.0) emittance_y = ( - emittance_y if emittance_y is not None else torch.full(shape, 7.1971891e-13) + emittance_y + if emittance_y is not None + else torch.tensor(7.1971891e-13, **factory_kwargs) + ) + sigma_tau = ( + sigma_tau if sigma_tau is not None else torch.tensor(1e-6, **factory_kwargs) ) - sigma_tau = sigma_tau if sigma_tau is not None else torch.full(shape, 1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.full(shape, 1e-6) - cor_tau = cor_tau if cor_tau is not None else torch.full(shape, 0.0) - energy = energy if energy is not None else torch.full(shape, 1e8) + sigma_p = ( + sigma_p if sigma_p is not None else torch.tensor(1e-6, **factory_kwargs) + ) + cor_tau = ( + cor_tau if cor_tau is not None else torch.tensor(0.0, **factory_kwargs) + ) + energy = energy if energy is not None else torch.tensor(1e8, **factory_kwargs) total_charge = ( - total_charge if total_charge is not None else torch.full(shape, 0.0) + total_charge + if total_charge is not None + else torch.tensor(0.0, **factory_kwargs) ) assert torch.all( @@ -287,7 +345,7 @@ def transformed_to( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParameterBeam": """ Create version of this beam that is transformed to new beam parameters. diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 5051002b..77aafb9f 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -6,7 +6,12 @@ from torch.distributions import MultivariateNormal from cheetah.particles.beam import Beam -from cheetah.utils import elementwise_linspace +from cheetah.utils import ( + elementwise_linspace, + unbiased_weighted_covariance, + unbiased_weighted_std, + verify_device_and_dtype, +) speed_of_light = torch.tensor(constants.speed_of_light) # In m/s electron_mass = torch.tensor(constants.electron_mass) # In kg @@ -21,7 +26,10 @@ class ParticleBeam(Beam): :param particles: List of 7-dimensional particle vectors. :param energy: Reference energy of the beam in eV. - :param total_charge: Total charge of the beam in C. + :param particle_charges: Charges of the macroparticles in the beam in C. + :param survival_probabilities: Vector of probabilities that each particle has + survived (i.e. not been lost), where 1.0 means the particle has survived and + 0.0 means the particle has been lost. Defaults to ones. :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. 
:param dtype: Data type of the generated particles. @@ -32,11 +40,15 @@ def __init__( particles: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, + survival_probabilities: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: - super().__init__() + device, dtype = verify_device_and_dtype( + [particles, energy, particle_charges], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() assert ( particles.shape[-2] > 0 and particles.shape[-1] == 7 @@ -48,10 +60,18 @@ def __init__( ( particle_charges.to(**factory_kwargs) if particle_charges is not None - else torch.zeros(particles.shape[:2], **factory_kwargs) + else torch.zeros(particles.shape[-2], **factory_kwargs) ), ) self.register_buffer("energy", energy.to(**factory_kwargs)) + self.register_buffer( + "survival_probabilities", + ( + survival_probabilities.to(**factory_kwargs) + if survival_probabilities is not None + else torch.ones(particles.shape[-2], **factory_kwargs) + ), + ) @classmethod def from_parameters( @@ -73,7 +93,7 @@ def from_parameters( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParticleBeam": """ Generate Cheetah Beam of random particles. @@ -97,30 +117,71 @@ def from_parameters( :param cor_y: Correlation between y and py. :param cor_tau: Correlation between s and p. :param energy: Energy of the beam in eV. - :total_charge: Total charge of the beam in C. + :param total_charge: Total charge of the beam in C. :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. :param dtype: Data type of the generated particles. 
""" + # Extract device and dtype from given arguments + device, dtype = verify_device_and_dtype( + [ + mu_x, + mu_px, + mu_y, + mu_py, + sigma_x, + sigma_px, + sigma_y, + sigma_py, + sigma_tau, + sigma_p, + cor_x, + cor_y, + cor_tau, + energy, + total_charge, + ], + device, + dtype, + ) + factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature - mu_x = mu_x if mu_x is not None else torch.tensor(0.0) - mu_px = mu_px if mu_px is not None else torch.tensor(0.0) - mu_y = mu_y if mu_y is not None else torch.tensor(0.0) - mu_py = mu_py if mu_py is not None else torch.tensor(0.0) - sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) - sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) - cor_x = cor_x if cor_x is not None else torch.tensor(0.0) - cor_y = cor_y if cor_y is not None else torch.tensor(0.0) - cor_tau = cor_tau if cor_tau is not None else torch.tensor(0.0) - energy = energy if energy is not None else torch.tensor(1e8) - total_charge = total_charge if total_charge is not None else torch.tensor(0.0) + mu_x = mu_x if mu_x is not None else torch.tensor(0.0, **factory_kwargs) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0, **factory_kwargs) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0, **factory_kwargs) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0, **factory_kwargs) + sigma_x = ( + sigma_x if sigma_x is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_px = ( + sigma_px if sigma_px is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_y = ( + sigma_y if sigma_y is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_py = ( + sigma_py if sigma_py is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_tau = ( + sigma_tau if sigma_tau is not None else torch.tensor(1e-6, **factory_kwargs) + ) + sigma_p = ( + sigma_p if sigma_p is not None else torch.tensor(1e-6, **factory_kwargs) + ) + cor_x = cor_x if cor_x is not None else torch.tensor(0.0, **factory_kwargs) + cor_y = cor_y if cor_y is not None else torch.tensor(0.0, **factory_kwargs) + cor_tau = ( + cor_tau if cor_tau is not None else torch.tensor(0.0, **factory_kwargs) + ) + energy = energy if energy is not None else torch.tensor(1e8, **factory_kwargs) + total_charge = ( + total_charge + if total_charge is not None + else torch.tensor(0.0, **factory_kwargs) + ) particle_charges = ( - torch.ones((*total_charge.shape, num_particles)) + torch.ones((*total_charge.shape, num_particles), **factory_kwargs) * total_charge.unsqueeze(-1) / num_particles ) @@ -152,7 +213,7 @@ def from_parameters( cor_tau, sigma_p, ) - cov = torch.zeros(*sigma_x.shape, 6, 6) + cov = torch.zeros(*sigma_x.shape, 6, 6, **factory_kwargs) cov[..., 0, 0] = sigma_x**2 cov[..., 0, 1] = cor_x cov[..., 1, 0] = cor_x @@ -166,7 +227,10 @@ def from_parameters( cov[..., 5, 4] = cor_tau cov[..., 5, 5] = sigma_p**2 - particles = torch.ones((*mean.shape[:-1], num_particles, 7)) + vector_shape = torch.broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean = mean.expand(*vector_shape, 6) + cov = cov.expand(*vector_shape, 6, 6) + particles = torch.ones((*vector_shape, num_particles, 7), **factory_kwargs) distributions = [ 
MultivariateNormal(sample_mean, covariance_matrix=sample_cov) for sample_mean, sample_cov in zip(mean.view(-1, 6), cov.view(-1, 6, 6)) @@ -174,7 +238,7 @@ def from_parameters( particles[..., :6] = torch.stack( [distribution.sample((num_particles,)) for distribution in distributions], dim=0, - ).view(*particles.shape[:-2], num_particles, 6) + ).view(*vector_shape, num_particles, 6) return cls( particles, @@ -187,7 +251,7 @@ def from_parameters( @classmethod def from_twiss( cls, - num_particles: int = 1_000_000, + num_particles: int = 100_000, beta_x: Optional[torch.Tensor] = None, alpha_x: Optional[torch.Tensor] = None, emittance_x: Optional[torch.Tensor] = None, @@ -200,12 +264,11 @@ def from_twiss( cor_tau: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParticleBeam": - # Figure out if arguments were passed, figure out their shape - not_nones = [ - argument - for argument in [ + # Extract device and dtype from given arguments + device, dtype = verify_device_and_dtype( + [ beta_x, alpha_x, emittance_x, @@ -217,28 +280,45 @@ def from_twiss( sigma_p, cor_tau, total_charge, - ] - if argument is not None - ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) - if len(not_nones) > 1: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." + ], + device, + dtype, + ) + factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature - beta_x = beta_x if beta_x is not None else torch.full(shape, 0.0) - alpha_x = alpha_x if alpha_x is not None else torch.full(shape, 0.0) - emittance_x = emittance_x if emittance_x is not None else torch.full(shape, 0.0) - beta_y = beta_y if beta_y is not None else torch.full(shape, 0.0) - alpha_y = alpha_y if alpha_y is not None else torch.full(shape, 0.0) - emittance_y = emittance_y if emittance_y is not None else torch.full(shape, 0.0) - energy = energy if energy is not None else torch.full(shape, 1e8) - sigma_tau = sigma_tau if sigma_tau is not None else torch.full(shape, 1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.full(shape, 1e-6) - cor_tau = cor_tau if cor_tau is not None else torch.full(shape, 0.0) + beta_x = beta_x if beta_x is not None else torch.tensor(0.0, **factory_kwargs) + alpha_x = ( + alpha_x if alpha_x is not None else torch.tensor(0.0, **factory_kwargs) + ) + emittance_x = ( + emittance_x + if emittance_x is not None + else torch.tensor(7.1971891e-13, **factory_kwargs) + ) + beta_y = beta_y if beta_y is not None else torch.tensor(0.0, **factory_kwargs) + alpha_y = ( + alpha_y if alpha_y is not None else torch.tensor(0.0, **factory_kwargs) + ) + emittance_y = ( + emittance_y + if emittance_y is not None + else torch.tensor(7.1971891e-13, **factory_kwargs) + ) + energy = energy if energy is not None else torch.tensor(1e8, **factory_kwargs) + sigma_tau = ( + sigma_tau if sigma_tau is not None else torch.tensor(1e-6, **factory_kwargs) + ) + sigma_p = ( + sigma_p if sigma_p is not None else torch.tensor(1e-6, **factory_kwargs) + ) + cor_tau = ( + cor_tau if cor_tau is not None else torch.tensor(0.0, **factory_kwargs) + ) total_charge = ( - total_charge if total_charge is not None else torch.full(shape, 0.0) + total_charge + if total_charge is not None + else torch.tensor(0.0, **factory_kwargs) ) sigma_x = torch.sqrt(beta_x * emittance_x) @@ -250,10 +330,10 @@ def from_twiss( return cls.from_parameters( 
            num_particles=num_particles,
-            mu_x=torch.full(shape, 0.0),
-            mu_px=torch.full(shape, 0.0),
-            mu_y=torch.full(shape, 0.0),
-            mu_py=torch.full(shape, 0.0),
+            mu_x=torch.tensor(0.0, **factory_kwargs),
+            mu_px=torch.tensor(0.0, **factory_kwargs),
+            mu_y=torch.tensor(0.0, **factory_kwargs),
+            mu_py=torch.tensor(0.0, **factory_kwargs),
             sigma_x=sigma_x,
             sigma_px=sigma_px,
             sigma_y=sigma_y,
@@ -272,7 +352,7 @@ def from_twiss(
     @classmethod
     def uniform_3d_ellipsoid(
         cls,
-        num_particles: int = 1_000_000,
+        num_particles: int = 100_000,
         radius_x: Optional[torch.Tensor] = None,
         radius_y: Optional[torch.Tensor] = None,
         radius_tau: Optional[torch.Tensor] = None,
@@ -282,7 +362,7 @@ def uniform_3d_ellipsoid(
         energy: Optional[torch.Tensor] = None,
         total_charge: Optional[torch.Tensor] = None,
         device=None,
-        dtype=torch.float32,
+        dtype=None,
     ):
         """
         Generate a particle beam with spatially uniformly distributed particles inside
@@ -312,11 +392,9 @@ def uniform_3d_ellipsoid(
         :return: ParticleBeam with uniformly distributed particles inside an
             ellipsoid.
         """
-
-        # Figure out if arguments were passed, figure out their shape
-        not_nones = [
-            argument
-            for argument in [
+        # Extract device and dtype from given arguments
+        device, dtype = verify_device_and_dtype(
+            [
                 radius_x,
                 radius_y,
                 radius_tau,
@@ -325,49 +403,51 @@ def uniform_3d_ellipsoid(
                 sigma_p,
                 energy,
                 total_charge,
-            ]
-            if argument is not None
-        ]
-        shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([])
-        if len(not_nones) > 1:
-            assert all(
-                argument.shape == shape for argument in not_nones
-            ), "Arguments must have the same shape."
-
-        # Expand to vectorised version for beam creation
-        vector_shape = shape if len(shape) > 0 else torch.Size([1])
+            ],
+            device,
+            dtype,
+        )
+        factory_kwargs = {"device": device, "dtype": dtype}

         # Set default values without function call in function signature
         # NOTE that this does not need to be done for values that are passed to the
         # Gaussian beam generation.
         radius_x = (
-            radius_x.expand(vector_shape)
-            if radius_x is not None
-            else torch.full(vector_shape, 1e-3)
+            radius_x if radius_x is not None else torch.tensor(1e-3, **factory_kwargs)
         )
         radius_y = (
-            radius_y.expand(vector_shape)
-            if radius_y is not None
-            else torch.full(vector_shape, 1e-3)
+            radius_y if radius_y is not None else torch.tensor(1e-3, **factory_kwargs)
         )
         radius_tau = (
-            radius_tau.expand(vector_shape)
+            radius_tau
             if radius_tau is not None
-            else torch.full(vector_shape, 1e-3)
+            else torch.tensor(1e-3, **factory_kwargs)
         )

-        # Generate x, y and ss within the ellipsoid
-        flattened_x = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2)
-        flattened_y = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2)
-        flattened_tau = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2)
+        # Generate x, y and tau within the ellipsoid
+        # Broadcasting with (1,) is a hack to make the loop work. Interestingly,
+        # this does not break the assignments into the non-vectorised particle tensor
+        # of the beam object.
+ vector_shape = torch.broadcast_shapes( + radius_x.shape, radius_y.shape, radius_tau.shape, (1,) + ) + flattened_x = torch.empty( + *vector_shape, num_particles, **factory_kwargs + ).flatten(end_dim=-2) + flattened_y = torch.empty( + *vector_shape, num_particles, **factory_kwargs + ).flatten(end_dim=-2) + flattened_tau = torch.empty( + *vector_shape, num_particles, **factory_kwargs + ).flatten(end_dim=-2) for i, (r_x, r_y, r_tau) in enumerate( zip(radius_x.flatten(), radius_y.flatten(), radius_tau.flatten()) ): num_successful = 0 while num_successful < num_particles: - x = (torch.rand(num_particles) - 0.5) * 2 * r_x - y = (torch.rand(num_particles) - 0.5) * 2 * r_y - tau = (torch.rand(num_particles) - 0.5) * 2 * r_tau + x = (torch.rand(num_particles, **factory_kwargs) - 0.5) * 2 * r_x + y = (torch.rand(num_particles, **factory_kwargs) - 0.5) * 2 * r_y + tau = (torch.rand(num_particles, **factory_kwargs) - 0.5) * 2 * r_tau is_in_ellipsoid = x**2 / r_x**2 + y**2 / r_y**2 + tau**2 / r_tau**2 < 1 num_to_add = min(num_particles - num_successful, is_in_ellipsoid.sum()) @@ -387,10 +467,13 @@ def uniform_3d_ellipsoid( # Generate an uncorrelated Gaussian beam beam = cls.from_parameters( num_particles=num_particles, - mu_px=torch.full(shape, 0.0), - mu_py=torch.full(shape, 0.0), + mu_px=torch.tensor(0.0, **factory_kwargs), + mu_py=torch.tensor(0.0, **factory_kwargs), + sigma_x=radius_x, # Only a placeholder, will be overwritten sigma_px=sigma_px, + sigma_y=radius_y, # Only a placeholder, will be overwritten sigma_py=sigma_py, + sigma_tau=radius_tau, # Only a placeholder, will be overwritten sigma_p=sigma_p, energy=energy, total_charge=total_charge, @@ -399,9 +482,9 @@ def uniform_3d_ellipsoid( ) # Replace the spatial coordinates with the generated ones - beam.x = flattened_x.view(*shape, num_particles) - beam.y = flattened_y.view(*shape, num_particles) - beam.tau = flattened_tau.view(*shape, num_particles) + beam.x = flattened_x.view(*vector_shape, num_particles) + beam.y = flattened_y.view(*vector_shape, num_particles) + beam.tau = flattened_tau.view(*vector_shape, num_particles) return beam @@ -422,7 +505,7 @@ def make_linspaced( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParticleBeam": """ Generate Cheetah Beam of *n* linspaced particles. @@ -446,22 +529,58 @@ def make_linspaced( CUDA GPU is selected if available. The CPU is used otherwise. :param dtype: Data type of the generated particles. 
""" + # Extract device and dtype from given arguments + device, dtype = verify_device_and_dtype( + [ + mu_x, + mu_px, + mu_y, + mu_py, + sigma_x, + sigma_px, + sigma_y, + sigma_py, + sigma_tau, + sigma_p, + energy, + total_charge, + ], + device, + dtype, + ) + factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature - mu_x = mu_x if mu_x is not None else torch.tensor(0.0) - mu_px = mu_px if mu_px is not None else torch.tensor(0.0) - mu_y = mu_y if mu_y is not None else torch.tensor(0.0) - mu_py = mu_py if mu_py is not None else torch.tensor(0.0) - sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) - sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) - energy = energy if energy is not None else torch.tensor(1e8) - total_charge = total_charge if total_charge is not None else torch.tensor(0.0) + mu_x = mu_x if mu_x is not None else torch.tensor(0.0, **factory_kwargs) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0, **factory_kwargs) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0, **factory_kwargs) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0, **factory_kwargs) + sigma_x = ( + sigma_x if sigma_x is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_px = ( + sigma_px if sigma_px is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_y = ( + sigma_y if sigma_y is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_py = ( + sigma_py if sigma_py is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_tau = ( + sigma_tau if sigma_tau is not None else torch.tensor(1e-6, **factory_kwargs) + ) + sigma_p = ( + sigma_p if sigma_p is not None else torch.tensor(1e-6, **factory_kwargs) + ) + energy = energy if energy is not None else torch.tensor(1e8, **factory_kwargs) + total_charge = ( + total_charge + if total_charge is not None + else torch.tensor(0.0, **factory_kwargs) + ) particle_charges = ( - torch.ones((*total_charge.shape, num_particles)) + torch.ones((*total_charge.shape, num_particles), **factory_kwargs) * total_charge.unsqueeze(-1) / num_particles ) @@ -478,7 +597,7 @@ def make_linspaced( sigma_tau.shape, sigma_p.shape, ) - particles = torch.ones((*vector_shape, num_particles, 7)) + particles = torch.ones((*vector_shape, num_particles, 7), **factory_kwargs) particles[..., 0] = elementwise_linspace( mu_x - sigma_x, mu_x + sigma_x, num_particles @@ -553,7 +672,7 @@ def transformed_to( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParticleBeam": """ Create version of this beam that is transformed to new beam parameters. 
@@ -693,6 +812,7 @@ def from_xyz_pxpypz( xp_coordinates: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, + survival_probabilities: Optional[torch.Tensor] = None, device=None, dtype=torch.float32, ) -> torch.Tensor: @@ -705,6 +825,7 @@ def from_xyz_pxpypz( particles=xp_coordinates.clone(), energy=energy, particle_charges=particle_charges, + survival_probabilities=survival_probabilities, device=device, dtype=dtype, ) @@ -770,15 +891,26 @@ def __len__(self) -> int: @property def total_charge(self) -> torch.Tensor: - return torch.sum(self.particle_charges, dim=-1) + """Total charge of the beam in C, taking into account particle losses.""" + return torch.sum(self.particle_charges * self.survival_probabilities, dim=-1) @property def num_particles(self) -> int: + """ + Length of the macroparticle array. + + NOTE: This does not account for lost particles. + """ return self.particles.shape[-2] + @property + def num_particles_survived(self) -> torch.Tensor: + """Number of macroparticles that have survived.""" + return self.survival_probabilities.sum(dim=-1) + @property def x(self) -> Optional[torch.Tensor]: - return self.particles[..., 0] if self is not Beam.empty else None + return self.particles[..., 0] @x.setter def x(self, value: torch.Tensor) -> None: @@ -786,15 +918,27 @@ def x(self, value: torch.Tensor) -> None: @property def mu_x(self) -> Optional[torch.Tensor]: - return self.x.mean(dim=-1) if self is not Beam.empty else None + """ + Mean of the :math:`x` coordinates of the particles, weighted by their + survival probability. + """ + return torch.sum( + (self.x * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_x(self) -> Optional[torch.Tensor]: - return self.x.std(dim=-1) if self is not Beam.empty else None + """ + Standard deviation of the :math:`x` coordinates of the particles, weighted + by their survival probability. + """ + return unbiased_weighted_std( + self.x, weights=self.survival_probabilities, dim=-1 + ) @property def px(self) -> Optional[torch.Tensor]: - return self.particles[..., 1] if self is not Beam.empty else None + return self.particles[..., 1] @px.setter def px(self, value: torch.Tensor) -> None: @@ -802,15 +946,27 @@ def px(self, value: torch.Tensor) -> None: @property def mu_px(self) -> Optional[torch.Tensor]: - return self.px.mean(dim=-1) if self is not Beam.empty else None + """ + Mean of the :math:`px` coordinates of the particles, weighted by their + survival probability. + """ + return torch.sum( + (self.px * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_px(self) -> Optional[torch.Tensor]: - return self.px.std(dim=-1) if self is not Beam.empty else None + """ + Standard deviation of the :math:`px` coordinates of the particles, weighted + by their survival probability. 
+ """ + return unbiased_weighted_std( + self.px, weights=self.survival_probabilities, dim=-1 + ) @property def y(self) -> Optional[torch.Tensor]: - return self.particles[..., 2] if self is not Beam.empty else None + return self.particles[..., 2] @y.setter def y(self, value: torch.Tensor) -> None: @@ -818,15 +974,19 @@ def y(self, value: torch.Tensor) -> None: @property def mu_y(self) -> Optional[float]: - return self.y.mean(dim=-1) if self is not Beam.empty else None + return torch.sum( + (self.y * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_y(self) -> Optional[torch.Tensor]: - return self.y.std(dim=-1) if self is not Beam.empty else None + return unbiased_weighted_std( + self.y, weights=self.survival_probabilities, dim=-1 + ) @property def py(self) -> Optional[torch.Tensor]: - return self.particles[..., 3] if self is not Beam.empty else None + return self.particles[..., 3] @py.setter def py(self, value: torch.Tensor) -> None: @@ -834,15 +994,19 @@ def py(self, value: torch.Tensor) -> None: @property def mu_py(self) -> Optional[torch.Tensor]: - return self.py.mean(dim=-1) if self is not Beam.empty else None + return torch.sum( + (self.py * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_py(self) -> Optional[torch.Tensor]: - return self.py.std(dim=-1) if self is not Beam.empty else None + return unbiased_weighted_std( + self.py, weights=self.survival_probabilities, dim=-1 + ) @property def tau(self) -> Optional[torch.Tensor]: - return self.particles[..., 4] if self is not Beam.empty else None + return self.particles[..., 4] @tau.setter def tau(self, value: torch.Tensor) -> None: @@ -850,15 +1014,19 @@ def tau(self, value: torch.Tensor) -> None: @property def mu_tau(self) -> Optional[torch.Tensor]: - return self.tau.mean(dim=-1) if self is not Beam.empty else None + return torch.sum( + (self.tau * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_tau(self) -> Optional[torch.Tensor]: - return self.tau.std(dim=-1) if self is not Beam.empty else None + return unbiased_weighted_std( + self.tau, weights=self.survival_probabilities, dim=-1 + ) @property def p(self) -> Optional[torch.Tensor]: - return self.particles[..., 5] if self is not Beam.empty else None + return self.particles[..., 5] @p.setter def p(self, value: torch.Tensor) -> None: @@ -866,24 +1034,34 @@ def p(self, value: torch.Tensor) -> None: @property def mu_p(self) -> Optional[torch.Tensor]: - return self.p.mean(dim=-1) if self is not Beam.empty else None + return torch.sum( + (self.p * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_p(self) -> Optional[torch.Tensor]: - return self.p.std(dim=-1) if self is not Beam.empty else None + return unbiased_weighted_std( + self.p, weights=self.survival_probabilities, dim=-1 + ) @property def sigma_xpx(self) -> torch.Tensor: - return torch.mean( - (self.x - self.mu_x.unsqueeze(-1)) * (self.px - self.mu_px.unsqueeze(-1)), - dim=-1, + r""" + Returns the covariance between x and px. :math:`\sigma_{x, px}^2`. + It is weighted by the survival probability of the particles. + """ + return unbiased_weighted_covariance( + self.x, self.px, weights=self.survival_probabilities, dim=-1 ) @property def sigma_ypy(self) -> torch.Tensor: - return torch.mean( - (self.y - self.mu_y.unsqueeze(-1)) * (self.py - self.mu_py.unsqueeze(-1)), - dim=-1, + r""" + Returns the covariance between y and py. 
:math:`\sigma_{y, py}^2`.
+        It is weighted by the survival probability of the particles.
+        """
+        return unbiased_weighted_covariance(
+            self.y, self.py, weights=self.survival_probabilities, dim=-1
         )

     @property
diff --git a/cheetah/utils/__init__.py b/cheetah/utils/__init__.py
index a29d9ae9..ba74ef57 100644
--- a/cheetah/utils/__init__.py
+++ b/cheetah/utils/__init__.py
@@ -1,6 +1,12 @@
 from . import bmadx  # noqa: F401
+from .argument_verification import verify_device_and_dtype  # noqa: F401
 from .device import is_mps_available_and_functional  # noqa: F401
 from .elementwise_linspace import elementwise_linspace  # noqa: F401
 from .kde import kde_histogram_1d, kde_histogram_2d  # noqa: F401
 from .physics import compute_relativistic_factors  # noqa: F401
+from .statistics import (  # noqa: F401
+    unbiased_weighted_covariance,
+    unbiased_weighted_std,
+    unbiased_weighted_variance,
+)
 from .unique_name_generator import UniqueNameGenerator  # noqa: F401
diff --git a/cheetah/utils/argument_verification.py b/cheetah/utils/argument_verification.py
new file mode 100644
index 00000000..23f8cc68
--- /dev/null
+++ b/cheetah/utils/argument_verification.py
@@ -0,0 +1,56 @@
+from typing import Optional
+
+import torch
+
+
+def are_all_the_same_device(tensors: list[torch.Tensor]) -> torch.device:
+    """
+    Determines whether all arguments are on the same device and, if so, returns that
+    device. If no arguments are passed, the global default PyTorch device is returned.
+    """
+    if len(tensors) > 1:
+        assert all(
+            argument.device == tensors[0].device for argument in tensors
+        ), "All tensors must be on the same device."
+
+    return tensors[0].device if len(tensors) > 0 else torch.get_default_device()
+
+
+def are_all_the_same_dtype(tensors: list[torch.Tensor]) -> torch.dtype:
+    """
+    Determines whether all arguments have the same dtype and, if so, returns that
+    dtype. If no arguments are passed, the global default PyTorch dtype is returned.
+    """
+    if len(tensors) > 1:
+        assert all(
+            argument.dtype == tensors[0].dtype for argument in tensors
+        ), "All arguments must have the same dtype."
+
+    return tensors[0].dtype if len(tensors) > 0 else torch.get_default_dtype()
+
+
+def verify_device_and_dtype(
+    tensors: list[Optional[torch.Tensor]],
+    desired_device: Optional[torch.device],
+    desired_dtype: Optional[torch.dtype],
+) -> tuple[torch.device, torch.dtype]:
+    """
+    Verifies that a unique device and dtype can be determined from the passed tensors
+    and the optional desired device and dtype. If no desired values are requested,
+    then all tensors (if they are not `None`) must have the same device and dtype.
+
+    If all verifications pass, this function returns the determined device and dtype.
+    """
+    not_nones = [tensor for tensor in tensors if tensor is not None]
+
+    chosen_device = (
+        desired_device
+        if desired_device is not None
+        else are_all_the_same_device(not_nones)
+    )
+    chosen_dtype = (
+        desired_dtype
+        if desired_dtype is not None
+        else are_all_the_same_dtype(not_nones)
+    )
+    return (chosen_device, chosen_dtype)
diff --git a/cheetah/utils/statistics.py b/cheetah/utils/statistics.py
new file mode 100644
index 00000000..adfbeba2
--- /dev/null
+++ b/cheetah/utils/statistics.py
@@ -0,0 +1,62 @@
+import torch
+
+
+def unbiased_weighted_covariance(
+    input1: torch.Tensor, input2: torch.Tensor, weights: torch.Tensor, dim: int = None
+) -> torch.Tensor:
+    """
+    Compute the unbiased weighted covariance of two tensors.
+
+    :param input1: Input tensor 1. (..., sample_size)
+    :param input2: Input tensor 2. (..., sample_size)
+    :param weights: Weights tensor. (..., sample_size)
+    :param dim: Dimension along which to compute the covariance.
+    :return: Unbiased weighted covariance. (...)
+    """
+    weighted_mean1 = torch.sum(input1 * weights, dim=dim) / torch.sum(weights, dim=dim)
+    weighted_mean2 = torch.sum(input2 * weights, dim=dim) / torch.sum(weights, dim=dim)
+    correction_factor = torch.sum(weights, dim=dim) - torch.sum(
+        weights**2, dim=dim
+    ) / torch.sum(weights, dim=dim)
+    covariance = torch.sum(
+        weights
+        * (input1 - weighted_mean1.unsqueeze(-1))
+        * (input2 - weighted_mean2.unsqueeze(-1)),
+        dim=dim,
+    ) / (correction_factor)
+    return covariance
+
+
+def unbiased_weighted_variance(
+    input: torch.Tensor, weights: torch.Tensor, dim: int = None
+) -> torch.Tensor:
+    """
+    Compute the unbiased weighted variance of a tensor.
+
+    :param input: Input tensor.
+    :param weights: Weights tensor.
+    :param dim: Dimension along which to compute the variance.
+    :return: Unbiased weighted variance.
+    """
+    weighted_mean = torch.sum(input * weights, dim=dim) / torch.sum(weights, dim=dim)
+    correction_factor = torch.sum(weights, dim=dim) - torch.sum(
+        weights**2, dim=dim
+    ) / torch.sum(weights, dim=dim)
+    variance = torch.sum(
+        weights * (input - weighted_mean.unsqueeze(-1)) ** 2, dim=dim
+    ) / (correction_factor)
+    return variance
+
+
+def unbiased_weighted_std(
+    input: torch.Tensor, weights: torch.Tensor, dim: int = None
+) -> torch.Tensor:
+    """
+    Compute the unbiased weighted standard deviation of a tensor.
+
+    :param input: Input tensor.
+    :param weights: Weights tensor.
+    :param dim: Dimension along which to compute the standard deviation.
+    :return: Unbiased weighted standard deviation.
+    """
+    return torch.sqrt(unbiased_weighted_variance(input, weights, dim=dim))
diff --git a/tests/resources/bmadx/incoming.pt b/tests/resources/bmadx/incoming.pt
index c901e8aa..5279a3d6 100644
Binary files a/tests/resources/bmadx/incoming.pt and b/tests/resources/bmadx/incoming.pt differ
diff --git a/tests/test_compare_beam_type.py b/tests/test_compare_beam_type.py
index 6a6d3e28..b995a843 100644
--- a/tests/test_compare_beam_type.py
+++ b/tests/test_compare_beam_type.py
@@ -241,6 +241,7 @@ def test_cavity_from_twiss():

     # Particle beam
     incoming_particle_beam = cheetah.ParticleBeam.from_twiss(
+        num_particles=1_000_000,
         beta_x=torch.tensor(5.91253677),
         alpha_x=torch.tensor(3.55631308),
         beta_y=torch.tensor(5.91253677),
diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py
index 1478454e..bc1ab083 100644
--- a/tests/test_compare_ocelot.py
+++ b/tests/test_compare_ocelot.py
@@ -176,7 +176,10 @@ def test_aperture():
     navigator.activate_apertures()
     _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator)

-    assert outgoing_beam.num_particles == outgoing_p_array.rparticles.shape[1]
+    assert (
+        int(outgoing_beam.num_particles_survived)
+        == outgoing_p_array.rparticles.shape[1]
+    )


 def test_aperture_elliptical():
@@ -215,7 +218,13 @@ def test_aperture_elliptical():
     navigator.activate_apertures()
     _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator)

-    assert outgoing_beam.num_particles == outgoing_p_array.rparticles.shape[1]
+    assert (
+        int(outgoing_beam.num_particles_survived)
+        == outgoing_p_array.rparticles.shape[1]
+    )
+
+    assert np.allclose(outgoing_beam.mu_x.cpu().numpy(), outgoing_p_array.x().mean())
+    assert np.allclose(outgoing_beam.mu_px.cpu().numpy(), outgoing_p_array.px().mean())


 def test_solenoid():
diff --git a/tests/test_device_dtype.py b/tests/test_device_dtype.py
index fa3e0a45..b293b6bc 100644
--- a/tests/test_device_dtype.py
+++ b/tests/test_device_dtype.py
@@ -44,12 +44,85 @@ def test_move_quadrupole_to_device(target_device: torch.device):
     assert quad.tilt.device.type == target_device.type


+@pytest.mark.parametrize(
+    "ElementClass",
+    [
+        cheetah.Cavity,
+        cheetah.Dipole,
+        cheetah.Drift,
+        cheetah.HorizontalCorrector,
+        cheetah.Quadrupole,
+        cheetah.RBend,
+        cheetah.Solenoid,
+        cheetah.TransverseDeflectingCavity,
+        cheetah.Undulator,
+        cheetah.VerticalCorrector,
+    ],
+)
+def test_forced_element_dtype(ElementClass):
+    """
+    Test that the dtype is properly overridden for all element classes.
+    """
+    element = ElementClass(
+        length=torch.tensor(1.0, dtype=torch.float64), dtype=torch.float16
+    )
+
+    for buffer in element.buffers():
+        assert buffer.dtype == torch.float16
+
+
+@pytest.mark.parametrize(
+    "ElementClass",
+    [
+        cheetah.Cavity,
+        cheetah.Dipole,
+        cheetah.Drift,
+        cheetah.HorizontalCorrector,
+        cheetah.Quadrupole,
+        cheetah.RBend,
+        cheetah.Solenoid,
+        cheetah.TransverseDeflectingCavity,
+        cheetah.Undulator,
+        cheetah.VerticalCorrector,
+    ],
+)
+def test_infer_element_dtype(ElementClass):
+    """
+    Test that the dtype is properly inferred for all element classes.
+    """
+    element = ElementClass(length=torch.tensor(1.0, dtype=torch.float64))
+
+    for buffer in element.buffers():
+        assert buffer.dtype == torch.float64
+
+
+def test_conflicting_quadrupole_dtype():
+    """
+    Test that creating a quadrupole with conflicting argument dtypes fails.
+    """
+    with pytest.raises(AssertionError):
+        cheetah.Quadrupole(
+            length=torch.tensor(1.0, dtype=torch.float32),
+            k1=torch.tensor(10.0, dtype=torch.float64),
+        )
+
+    # Ensure that the conflict can be solved by explicit dtype selection
+    quad = cheetah.Quadrupole(
+        length=torch.tensor(1.0, dtype=torch.float32),
+        k1=torch.tensor(10.0, dtype=torch.float64),
+        dtype=torch.float16,
+    )
+    assert quad.length.dtype == torch.float16
+
+
 def test_change_quadrupole_dtype():
     """
     Test that a quadrupole magnet can be successfully changed to a different dtype.
     """
     quad = cheetah.Quadrupole(
-        length=torch.tensor(0.2), k1=torch.tensor(4.2), name="my_quad"
+        length=torch.tensor(0.2),
+        k1=torch.tensor(4.2),
+        name="my_quad",
     )

     # Test that by default the quadrupole is of dtype float32
@@ -102,6 +175,66 @@ def test_move_particlebeam_to_device(target_device: torch.device):
     assert beam.total_charge.device.type == target_device.type


+@pytest.mark.parametrize(
+    "BeamClass",
+    [
+        cheetah.ParameterBeam,
+        cheetah.ParticleBeam,
+    ],
+)
+def test_forced_beam_dtype(BeamClass):
+    """
+    Test that the dtype is properly overridden on beam creation.
+    """
+    beam = BeamClass.from_parameters(
+        mu_x=torch.tensor(1e-5, dtype=torch.float32), dtype=torch.float64
+    )
+    for buffer in beam.buffers():
+        assert buffer.dtype == torch.float64
+
+    beam = BeamClass.from_twiss(
+        beta_x=torch.tensor(1.0, dtype=torch.float16),
+        beta_y=torch.tensor(2.0, dtype=torch.float64),
+        dtype=torch.float32,
+    )
+    for buffer in beam.buffers():
+        assert buffer.dtype == torch.float32
+
+
+@pytest.mark.parametrize(
+    "BeamClass",
+    [
+        cheetah.ParameterBeam,
+        cheetah.ParticleBeam,
+    ],
+)
+def test_infer_beam_dtype(BeamClass):
+    """
+    Test that the dtype is properly inferred on beam creation.
+ """ + beam = BeamClass.from_parameters(mu_x=torch.tensor(1e-5, dtype=torch.float64)) + for buffer in beam.buffers(): + assert buffer.dtype == torch.float64 + + beam = BeamClass.from_twiss( + beta_x=torch.tensor(1.0, dtype=torch.float64), + beta_y=torch.tensor(2.0, dtype=torch.float64), + ) + for buffer in beam.buffers(): + assert buffer.dtype == torch.float64 + + +def test_conflicting_particlebeam_dtype(): + """ + Test if creating a ParticleBeam with conflicting argument dtypes fails. + """ + with pytest.raises(AssertionError): + cheetah.ParticleBeam.from_twiss( + beta_x=torch.tensor(1.0, dtype=torch.float32), + beta_y=torch.tensor(2.0, dtype=torch.float64), + ) + + def test_change_particlebeam_dtype(): """ Test that a particle beam can be successfully changed to a different dtype. @@ -119,3 +252,27 @@ def test_change_particlebeam_dtype(): assert beam.particles.dtype == torch.float64 assert beam.energy.dtype == torch.float64 assert beam.total_charge.dtype == torch.float64 + + +@pytest.mark.parametrize( + "BeamClass", + [ + cheetah.ParameterBeam, + cheetah.ParticleBeam, + ], +) +def test_transformed_beam_dtype(BeamClass): + """ + Test that Beam.transformed_to keeps the dtype by default. + """ + beam = BeamClass.from_parameters(mu_x=torch.tensor(1e-5), dtype=torch.float64) + + # Verify the dtype is kept by default + transformed_beam = beam.transformed_to(mu_x=torch.tensor(-2e-5)) + assert transformed_beam.mu_x.dtype == torch.float64 + + # Check that the manual dtype selection works + transformed_beam = beam.transformed_to( + mu_x=torch.tensor(-2e-5), dtype=torch.float32 + ) + assert transformed_beam.mu_x.dtype == torch.float32 diff --git a/tests/test_dipole.py b/tests/test_dipole.py index 08d016e4..231e7ddf 100644 --- a/tests/test_dipole.py +++ b/tests/test_dipole.py @@ -121,19 +121,20 @@ def test_dipole_bmadx_tracking(dtype): dtype ) - angle = torch.tensor([20 * torch.pi / 180], dtype=dtype) + # TODO: See if Bmad-X test dtypes can be cleaned up now that dtype PR was merged + angle = torch.tensor(20 * torch.pi / 180, dtype=dtype) e1 = angle / 2 e2 = angle - e1 dipole_cheetah_bmadx = Dipole( - length=torch.tensor([0.5]), + length=torch.tensor(0.5), angle=angle, e1=e1, e2=e2, - tilt=torch.tensor([0.1], dtype=dtype), - fringe_integral=torch.tensor([0.5]), - fringe_integral_exit=torch.tensor([0.5]), - gap=torch.tensor([0.05], dtype=dtype), - gap_exit=torch.tensor([0.05], dtype=dtype), + tilt=torch.tensor(0.1, dtype=dtype), + fringe_integral=torch.tensor(0.5), + fringe_integral_exit=torch.tensor(0.5), + gap=torch.tensor(0.05, dtype=dtype), + gap_exit=torch.tensor(0.05, dtype=dtype), fringe_at="both", fringe_type="linear_edge", tracking_method="bmadx", diff --git a/tests/test_drift.py b/tests/test_drift.py index f0115216..bb5ba4b8 100644 --- a/tests/test_drift.py +++ b/tests/test_drift.py @@ -71,7 +71,7 @@ def test_drift_bmadx_tracking(dtype): "tests/resources/bmadx/incoming.pt", weights_only=False ).to(dtype) drift = cheetah.Drift( - length=torch.tensor([1.0]), tracking_method="bmadx", dtype=dtype + length=torch.tensor(1.0), tracking_method="bmadx", dtype=dtype ) # Run tracking diff --git a/tests/test_elegant_conversion.py b/tests/test_elegant_conversion.py index 5e4ac94b..01fdfbaf 100644 --- a/tests/test_elegant_conversion.py +++ b/tests/test_elegant_conversion.py @@ -17,14 +17,14 @@ def test_fodo(): cheetah.Quadrupole( name="q1", length=torch.tensor(0.1), k1=torch.tensor(1.5) ), - cheetah.Drift(name="d1", length=torch.tensor(1)), + cheetah.Drift(name="d1", length=torch.tensor(1.0)), 
cheetah.Marker(name="m1"), cheetah.Dipole(name="s1", length=torch.tensor(0.3), e1=torch.tensor(0.25)), - cheetah.Drift(name="d1", length=torch.tensor(1)), + cheetah.Drift(name="d1", length=torch.tensor(1.0)), cheetah.Quadrupole( - name="q2", length=torch.tensor(0.2), k1=torch.tensor(-3) + name="q2", length=torch.tensor(0.2), k1=torch.tensor(-3.0) ), - cheetah.Drift(name="d2", length=torch.tensor(2)), + cheetah.Drift(name="d2", length=torch.tensor(2.0)), ], name="fodo", ) diff --git a/tests/test_particle_beam.py b/tests/test_particle_beam.py index f091ab1b..49003d8d 100644 --- a/tests/test_particle_beam.py +++ b/tests/test_particle_beam.py @@ -148,6 +148,19 @@ def test_generate_uniform_ellipsoid_vectorized(): assert torch.allclose(beam.total_charge, total_charge) +def test_only_sigma_vectorized(): + """ + Test that particle beam works correctly when only a vectorised sigma is given and + all else is scalar. + """ + beam = ParticleBeam.from_parameters( + num_particles=10_000, + mu_x=torch.tensor(1e-5), + sigma_x=torch.tensor([1.75e-7, 2.75e-7]), + ) + assert beam.particles.shape == (2, 10_000, 7) + + def test_indexing(): # test batching with beamline parameters quadrupole = cheetah.Quadrupole( diff --git a/tests/test_screen.py b/tests/test_screen.py index 0a1a43c3..9b6f243d 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -16,7 +16,7 @@ def test_reading_shows_beam_particle(screen_method): elements=[ cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( - resolution=torch.tensor((100, 100)), + resolution=(100, 100), pixel_size=torch.tensor((1e-5, 1e-5)), is_active=True, method=screen_method, @@ -46,12 +46,12 @@ def test_screen_kde_bandwidth(kde_bandwidth): elements=[ cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( - resolution=torch.tensor((100, 100)), + resolution=(100, 100), pixel_size=torch.tensor((1e-5, 1e-5)), is_active=True, method="kde", name="my_screen", - kde_bandwidth=kde_bandwidth, + kde_bandwidth=torch.tensor(kde_bandwidth), ), ], ) @@ -78,7 +78,7 @@ def test_reading_shows_beam_parameter(screen_method): elements=[ cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( - resolution=torch.tensor((100, 100)), + resolution=(100, 100), pixel_size=torch.tensor((1e-5, 1e-5)), is_active=True, method=screen_method, @@ -113,15 +113,13 @@ def test_reading_shows_beam_ares(screen_method): segment.AREABSCR1.method = screen_method - segment.AREABSCR1.resolution = torch.tensor( - (2448, 2040), device=segment.AREABSCR1.resolution.device - ) + segment.AREABSCR1.resolution = (2448, 2040) segment.AREABSCR1.pixel_size = torch.tensor( (3.3198e-6, 2.4469e-6), device=segment.AREABSCR1.pixel_size.device, dtype=segment.AREABSCR1.pixel_size.dtype, ) - segment.AREABSCR1.binning = torch.tensor(1, device=segment.AREABSCR1.binning.device) + segment.AREABSCR1.binning = 1 segment.AREABSCR1.is_active = True assert isinstance(segment.AREABSCR1.reading, torch.Tensor) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 8d91cb32..dad61da0 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -269,8 +269,54 @@ def test_space_charge_with_ares_astra_beam(): `IndexError: index -38 is out of bounds for dimension 3 with size 32`. 
""" segment = cheetah.Segment( - [cheetah.Drift(length=1.0), cheetah.SpaceChargeKick(effect_length=1.0)] + [ + cheetah.Drift(length=torch.tensor(1.0)), + cheetah.SpaceChargeKick(effect_length=torch.tensor(1.0)), + ] ) beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") _ = segment.track(beam) + + +def test_space_charge_with_aperture_cutoff(): + """ + Tests that the space charge kick is correctly applied only to surviving particles, + by comparing the results with and without an aperture that results in beam losses. + """ + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(0.2)), + cheetah.Aperture( + x_max=torch.tensor(1e-4), + y_max=torch.tensor(1e-4), + shape="rectangular", + is_active="False", + name="aperture", + ), + cheetah.Drift(length=torch.tensor(0.25)), + cheetah.SpaceChargeKick(effect_length=torch.tensor(0.5)), + cheetah.Drift(length=torch.tensor(0.25)), + ] + ) + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + total_charge=torch.tensor(1e-9), + mu_x=torch.tensor(5e-5), + sigma_px=torch.tensor(1e-4), + sigma_py=torch.tensor(1e-4), + ) + + # Track with inactive aperture + outgoing_beam_without_aperture = segment.track(incoming_beam) + + # Activate the aperture and track the beam + segment.aperture.is_active = True + outgoing_beam_with_aperture = segment.track(incoming_beam) + + # Check that with particle loss the space charge kick is different + assert not torch.allclose( + outgoing_beam_with_aperture.particles, outgoing_beam_without_aperture.particles + ) + # Check that the number of surviving particles is less than the initial number + assert outgoing_beam_with_aperture.survival_probabilities.sum(dim=-1).max() < 10_000 diff --git a/tests/test_statistics.py b/tests/test_statistics.py new file mode 100644 index 00000000..3b603851 --- /dev/null +++ b/tests/test_statistics.py @@ -0,0 +1,73 @@ +import torch + +from cheetah.utils import unbiased_weighted_covariance, unbiased_weighted_variance + + +def test_unbiased_weighted_variance_with_single_element(): + """Test that the variance is NaN when there is only one element.""" + data = torch.tensor([42.0]) + weights = torch.tensor([1.0]) + + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.isnan(computed_variance) + + +def test_unbiased_weighted_variance_with_same_weights(): + """ + Test that the weighted variance with all weights the same equals the unweighted + variance implementated in PyTorch. 
+ """ + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) + + expected_variance = torch.var(data, unbiased=True) + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.allclose(computed_variance, expected_variance) + + +def test_unbiased_weighted_variance_with_different_weights(): + """Test that the variance is computed when some weights are different.""" + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.5, 0.5, 0.5, 0, 0]) + + expected_variance = torch.var(torch.tensor([1.0, 2.0, 3.0]), unbiased=True) + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.allclose(computed_variance, expected_variance) + + +def test_unbiased_weighted_variance_with_zero_weights(): + """Test that the variance is NaN when all weights are zero.""" + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]) + + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.isnan(computed_variance) + + +def test_unbiased_weighted_variance_with_small_numbers(): + """Test that the variance is correct for small numbers.""" + data = torch.tensor([1e-10, 2e-10, 3e-10, 4e-10, 5e-10], dtype=torch.float32) + weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32) + + expected_variance = torch.var(data, unbiased=True) + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.allclose(computed_variance, expected_variance) + + +def test_unbiased_weighted_covariance_reduced_to_variance(): + """ + Test that the covariance computation is correctly reduced to the variance when both + inputs are the same. + """ + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.5, 1.0, 1.0, 0.9, 0.9]) + + variance = unbiased_weighted_variance(data, weights) + covariance = unbiased_weighted_covariance(data, data, weights) + + assert torch.allclose(covariance, variance) diff --git a/tests/test_transverse_deflecting_cavity.py b/tests/test_transverse_deflecting_cavity.py index 3555532b..a9662aa8 100644 --- a/tests/test_transverse_deflecting_cavity.py +++ b/tests/test_transverse_deflecting_cavity.py @@ -14,10 +14,10 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype): "tests/resources/bmadx/incoming.pt", weights_only=False ).to(dtype) tdc = cheetah.TransverseDeflectingCavity( - length=torch.tensor([1.0]), - voltage=torch.tensor([1e7]), - phase=torch.tensor([0.2], dtype=dtype), - frequency=torch.tensor([1e9]), + length=torch.tensor(1.0), + voltage=torch.tensor(1e7), + phase=torch.tensor(0.2, dtype=dtype), + frequency=torch.tensor(1e9), tracking_method="bmadx", dtype=dtype, ) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 5fa3d362..aa00598a 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -309,7 +310,7 @@ def test_vectorized_screen_2d(BeamClass, method): elements=[ cheetah.Drift(length=torch.tensor(1.0)), cheetah.Screen( - resolution=torch.tensor((100, 100)), + resolution=(100, 100), pixel_size=torch.tensor((1e-5, 1e-5)), misalignment=torch.tensor( [ @@ -449,3 +450,47 @@ def test_broadcasting_solenoid_misalignment(): assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.particle_charges.shape == (100_000,) assert outgoing.energy.shape == (2,) + + +@pytest.mark.parametrize("aperture_shape", ["rectangular", "elliptical"]) +def 
test_vectorized_aperture_broadcasting(aperture_shape): + """ + Test that apertures work in a vectorised setting and that broadcasting rules are + applied correctly. + """ + torch.manual_seed(0) + + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=100_000, + sigma_py=torch.tensor(1e-4), + sigma_px=torch.tensor(2e-4), + energy=torch.tensor([154e6, 14e9]), + ) + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(0.5)), + cheetah.Aperture( + x_max=torch.tensor([[1e-5], [2e-4], [3e-4]]), + y_max=torch.tensor(2e-4), + shape=aperture_shape, + ), + cheetah.Drift(length=torch.tensor(0.5)), + ] + ) + + outgoing = segment.track(incoming) + + # Particle positions are unaffected by the aperture ... only their survival is + assert outgoing.particles.shape == (2, 100_000, 7) + assert outgoing.energy.shape == (2,) + assert outgoing.particle_charges.shape == (100_000,) + assert outgoing.survival_probabilities.shape == (3, 2, 100_000) + + if aperture_shape == "elliptical": + assert np.allclose( + outgoing.survival_probabilities.sum(dim=-1)[:, 0], [7672, 94523, 99547] + ) + elif aperture_shape == "rectangular": + assert np.allclose( + outgoing.survival_probabilities.sum(dim=-1)[:, 0], [7935, 95400, 99719] + )
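The broadcasting test above also documents the core of the new loss model: an `Aperture` no longer changes the shape of the particle array; it only scales `survival_probabilities`, and beam statistics are weighted by those probabilities. A small usage sketch under the same API (the beam parameters are illustrative, and the aperture is assumed to use its rectangular, active defaults):

```python
import torch

import cheetah

incoming = cheetah.ParticleBeam.from_parameters(
    num_particles=10_000,
    sigma_px=torch.tensor(2e-4),
    sigma_py=torch.tensor(1e-4),
)
segment = cheetah.Segment(
    elements=[
        cheetah.Drift(length=torch.tensor(0.5)),
        cheetah.Aperture(x_max=torch.tensor(1e-4), y_max=torch.tensor(1e-4)),
    ]
)
outgoing = segment.track(incoming)

# The macroparticle array keeps its shape; losses live in the probabilities.
assert outgoing.particles.shape == incoming.particles.shape
assert outgoing.num_particles == 10_000
assert outgoing.num_particles_survived <= 10_000

# Statistics such as sigma_x and the total charge are weighted by the
# survival probabilities, so lost particles no longer contribute.
print(outgoing.sigma_x, outgoing.total_charge)
```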