From dada294f1fe5148b244d7c3dfa18d1883eced52d Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 3 Oct 2024 20:27:53 +0200 Subject: [PATCH 01/29] Add test to detect `Aperture` vectorisation issue --- tests/test_vectorized.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 87ebd03e..b58be54a 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -405,3 +405,28 @@ def test_vectorized_parameter_beam_creation(): assert torch.allclose(beam.mu_x, torch.tensor([2e-4, 3e-4])) assert beam.sigma_x.shape == (2,) assert torch.allclose(beam.sigma_x, torch.tensor([1e-5, 2e-5])) + + +def test_vectorized_aperture_broadcasting(): + """ + Test that apertures work in a vectorised setting and that broadcasting rules are + applied correctly. + """ + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=100_000, energy=torch.tensor([154e6, 14e9, 5e9]) + ) + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(0.5)), + cheetah.Aperture( + x_max=torch.tensor([[1e-3], [2e-3], [3e-3]]), y_max=torch.tensor(1e-3) + ), + cheetah.Drift(length=torch.tensor(0.5)), + ] + ) + + outgoing = segment.track(incoming) + + assert outgoing.particles.shape == (3, 2, 100_000, 7) + assert outgoing.particle_charges.shape == (100_000,) + assert outgoing.energy.shape == (2,) From e5cc55d96d8f649c43363d5885b3e52911af5d45 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 3 Oct 2024 20:54:46 +0200 Subject: [PATCH 02/29] Minor fixes to the vectorised aperture test itself --- tests/test_vectorized.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index b58be54a..2229e224 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,5 +1,6 @@ import pytest import torch +from icecream import ic import cheetah @@ -413,7 +414,7 @@ def test_vectorized_aperture_broadcasting(): applied correctly. """ incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, energy=torch.tensor([154e6, 14e9, 5e9]) + num_particles=100_000, energy=torch.tensor([154e6, 14e9]) ) segment = cheetah.Segment( elements=[ @@ -425,6 +426,8 @@ def test_vectorized_aperture_broadcasting(): ] ) + ic(incoming.energy.shape, segment.elements[1].x_max.shape) + outgoing = segment.track(incoming) assert outgoing.particles.shape == (3, 2, 100_000, 7) From d9691ccff1cde0743cc6579635dd57c7e875dd87 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 3 Oct 2024 20:58:06 +0200 Subject: [PATCH 03/29] Extend test to cover both aperture shapes --- tests/test_vectorized.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 2229e224..af4eb865 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,6 +1,5 @@ import pytest import torch -from icecream import ic import cheetah @@ -408,7 +407,8 @@ def test_vectorized_parameter_beam_creation(): assert torch.allclose(beam.sigma_x, torch.tensor([1e-5, 2e-5])) -def test_vectorized_aperture_broadcasting(): +@pytest.mark.parametrize("shape", ["rectangular", "elliptical"]) +def test_vectorized_aperture_broadcasting(shape): """ Test that apertures work in a vectorised setting and that broadcasting rules are applied correctly. 
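For readers following the broadcasting rules this test exercises: the expected (3, 2, 100_000, 7) particle shape comes from broadcasting the aperture's x_max of shape (3, 1) against the beam's energy of shape (2,). A minimal sketch of just that broadcasting step, assuming only PyTorch itself (the variable names below are illustrative and not part of Cheetah):

    import torch

    # Shapes taken from the test above.
    x_max_shape = torch.Size([3, 1])   # three aperture settings, with a trailing singleton dimension
    energy_shape = torch.Size([2])     # two reference energies
    vector_shape = torch.broadcast_shapes(x_max_shape, energy_shape)
    print(vector_shape)                # torch.Size([3, 2])
    # With 100_000 particles of 7 coordinates each, the tracked beam therefore has
    # particles of shape (3, 2, 100_000, 7), which is what the test asserts.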
@@ -420,14 +420,14 @@ def test_vectorized_aperture_broadcasting(): elements=[ cheetah.Drift(length=torch.tensor(0.5)), cheetah.Aperture( - x_max=torch.tensor([[1e-3], [2e-3], [3e-3]]), y_max=torch.tensor(1e-3) + x_max=torch.tensor([[1e-3], [2e-3], [3e-3]]), + y_max=torch.tensor(1e-3), + shape=shape, ), cheetah.Drift(length=torch.tensor(0.5)), ] ) - ic(incoming.energy.shape, segment.elements[1].x_max.shape) - outgoing = segment.track(incoming) assert outgoing.particles.shape == (3, 2, 100_000, 7) From b175b788ac3f9e938d437bd2778247cadc1abe29 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 3 Oct 2024 21:11:07 +0200 Subject: [PATCH 04/29] Change logical accumulators to torch ones --- cheetah/accelerator/aperture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index 2d1d9ad0..4ff496c7 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -72,7 +72,7 @@ def track(self, incoming: Beam) -> Beam: if not (isinstance(incoming, ParticleBeam) and self.is_active): return incoming - assert self.x_max >= 0 and self.y_max >= 0 + assert torch.all(self.x_max >= 0) and torch.all(self.y_max >= 0) assert self.shape in [ "rectangular", "elliptical", From 3f5d48fb86f8361e13f052c943378c5cb1640b2d Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Thu, 24 Oct 2024 15:50:53 +0200 Subject: [PATCH 05/29] Add `particle_survival` probability to the `ParticleBeam` --- cheetah/accelerator/aperture.py | 53 +++++++++++-------- cheetah/accelerator/cavity.py | 1 + cheetah/accelerator/dipole.py | 1 + cheetah/accelerator/drift.py | 1 + cheetah/accelerator/element.py | 1 + cheetah/accelerator/quadrupole.py | 1 + cheetah/accelerator/space_charge_kick.py | 18 ++++--- .../transverse_deflecting_cavity.py | 1 + cheetah/particles/particle_beam.py | 20 ++++++- tests/test_compare_ocelot.py | 10 +++- tests/test_dipole.py | 7 ++- tests/test_drift.py | 5 +- tests/test_quadrupole.py | 7 ++- tests/test_transverse_deflecting_cavity.py | 5 +- tests/test_vectorized.py | 25 +++++++-- 15 files changed, 116 insertions(+), 40 deletions(-) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index 2181cad0..b9bc6228 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -15,6 +15,8 @@ class Aperture(Element): """ Physical aperture. + The aperture is only considered if tracking a `ParticleBeam` and the aperture is + active. 
:param x_max: half size horizontal offset in [m] :param y_max: half size vertical offset in [m] @@ -78,35 +80,44 @@ def track(self, incoming: Beam) -> Beam: "elliptical", ], f"Unknown aperture shape {self.shape}" + # broadcast x_max and y_max and the ParticleBeam to the same shape + vector_shape = torch.broadcast_shapes( + self.x_max.shape, + self.y_max.shape, + incoming.x.shape[:-1], + incoming.energy.shape, + ) + x_max = self.x_max.expand(vector_shape).unsqueeze(-1) + y_max = self.y_max.expand(vector_shape).unsqueeze(-1) + outgoing_particles = incoming.particles.expand(vector_shape + (-1, 7)) + if self.shape == "rectangular": survived_mask = torch.logical_and( - torch.logical_and(incoming.x > -self.x_max, incoming.x < self.x_max), - torch.logical_and(incoming.y > -self.y_max, incoming.y < self.y_max), + torch.logical_and(incoming.x > -x_max, incoming.x < x_max), + torch.logical_and(incoming.y > -y_max, incoming.y < y_max), ) elif self.shape == "elliptical": - survived_mask = ( - incoming.x**2 / self.x_max**2 + incoming.y**2 / self.y_max**2 - ) <= 1.0 - outgoing_particles = incoming.particles[survived_mask] + survived_mask = (incoming.x**2 / x_max**2 + incoming.y**2 / y_max**2) <= 1.0 - outgoing_particle_charges = incoming.particle_charges[survived_mask] + outgoing_survival = incoming.particle_survival * survived_mask - self.lost_particles = incoming.particles[torch.logical_not(survived_mask)] + # outgoing_particles = incoming.particles[survived_mask] - self.lost_particle_charges = incoming.particle_charges[ - torch.logical_not(survived_mask) - ] + # outgoing_particle_charges = incoming.particle_charges[survived_mask] - return ( - ParticleBeam( - outgoing_particles, - incoming.energy, - particle_charges=outgoing_particle_charges, - device=outgoing_particles.device, - dtype=outgoing_particles.dtype, - ) - if outgoing_particles.shape[0] > 0 - else ParticleBeam.empty + # self.lost_particles = incoming.particles[torch.logical_not(survived_mask)] + + # self.lost_particle_charges = incoming.particle_charges[ + # torch.logical_not(survived_mask) + # ] + + return ParticleBeam( + outgoing_particles, + incoming.energy, + particle_charges=incoming.particle_charges, + device=incoming.particles.device, + dtype=incoming.particles.dtype, + particle_survival=outgoing_survival, ) def split(self, resolution: torch.Tensor) -> list[Element]: diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index c7a89e05..f12d124e 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -243,6 +243,7 @@ def _track_beam(self, incoming: Beam) -> Beam: particle_charges=incoming.particle_charges, device=outgoing_particles.device, dtype=outgoing_particles.dtype, + particle_survival=incoming.particle_survival, ) return outgoing diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 5e919fbf..ba3ed97f 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -235,6 +235,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, + particle_survival=incoming.particle_survival, ) return outgoing_beam diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index eb5fb187..cac58a37 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -117,6 +117,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=incoming.particle_charges, 
device=incoming.particles.device, dtype=incoming.particles.dtype, + particle_survival=incoming.particle_survival, ) return outgoing_beam diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index bfe1df7c..76b99c3f 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -83,6 +83,7 @@ def track(self, incoming: Beam) -> Beam: particle_charges=incoming.particle_charges, device=new_particles.device, dtype=new_particles.dtype, + particle_survival=incoming.particle_survival, ) else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index db6a559d..f2d94250 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -190,6 +190,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, + particle_survival=incoming.particle_survival, ) return outgoing_beam diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 7bb7d7f3..4e23cff3 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -563,6 +563,7 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=incoming.particle_charges.unsqueeze(0), device=incoming.particles.device, dtype=incoming.particles.dtype, + particle_survival=incoming.particle_survival.unsqueeze(0), ) else: is_incoming_vectorized = True @@ -577,6 +578,9 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ), device=vectorized_incoming.particles.device, dtype=vectorized_incoming.particles.dtype, + particle_survival=vectorized_incoming.particle_survival.flatten( + end_dim=-2 + ), ) flattened_length_effect = self.effect_length.flatten(end_dim=-1) @@ -614,9 +618,10 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: outgoing = ParticleBeam.from_xyz_pxpypz( xp_coordinates.squeeze(0), vectorized_incoming.energy.squeeze(0), - vectorized_incoming.particle_charges.squeeze(0), - vectorized_incoming.particles.device, - vectorized_incoming.particles.dtype, + particle_charges=vectorized_incoming.particle_charges.squeeze(0), + device=vectorized_incoming.particles.device, + dtype=vectorized_incoming.particles.dtype, + particle_survival=vectorized_incoming.particle_survival.squeeze(0), ) else: # Reverse the flattening of the vector dimensions @@ -625,9 +630,10 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: dim=0, sizes=vectorized_incoming.particles.shape[:-2] ), vectorized_incoming.energy, - vectorized_incoming.particle_charges, - vectorized_incoming.particles.device, - vectorized_incoming.particles.dtype, + particle_charges=vectorized_incoming.particle_charges, + particle_survival=vectorized_incoming.particle_survival, + device=vectorized_incoming.particles.device, + dtype=vectorized_incoming.particles.dtype, ) return outgoing else: diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index fd4ed2af..3e639919 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -207,6 +207,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), ref_energy, particle_charges=incoming.particle_charges, + particle_survival=incoming.particle_survival, 
device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 5b83bab1..51e32f62 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -21,7 +21,9 @@ class ParticleBeam(Beam): :param particles: List of 7-dimensional particle vectors. :param energy: Reference energy of the beam in eV. - :param total_charge: Total charge of the beam in C. + :param particle_charges: Charges of macroparticles in the beam in C. + :param particle_survival: Survival probability of each particle in the beam. + Default to array with ones. (1: survive, 0: lost) :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. """ @@ -31,6 +33,7 @@ def __init__( particles: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, + particle_survival: Optional[torch.Tensor] = None, device=None, dtype=torch.float32, ) -> None: @@ -51,6 +54,15 @@ def __init__( ), ) self.register_buffer("energy", energy.to(**factory_kwargs)) + if particle_survival is not None: + # Try to broadcast the survival probability to the particles shape + particle_survival = particle_survival.expand(particles.shape[:-1]) + else: + # If no survival probability provided, default to all particles surviving + particle_survival = torch.ones(particles.shape[:-1], **factory_kwargs) + self.register_buffer( + "particle_survival", particle_survival.to(**factory_kwargs) + ) @classmethod def from_parameters( @@ -699,6 +711,7 @@ def from_xyz_pxpypz( xp_coordinates: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, + particle_survival: Optional[torch.Tensor] = None, device=None, dtype=torch.float32, ) -> torch.Tensor: @@ -711,6 +724,7 @@ def from_xyz_pxpypz( particles=xp_coordinates.clone(), energy=energy, particle_charges=particle_charges, + particle_survival=particle_survival, device=device, dtype=dtype, ) @@ -782,6 +796,10 @@ def total_charge(self) -> torch.Tensor: def num_particles(self) -> int: return self.particles.shape[-2] + @property + def num_particles_survived(self) -> torch.Tensor: + return self.particle_survival.sum(dim=-1) + @property def x(self) -> Optional[torch.Tensor]: return self.particles[..., 0] if self is not Beam.empty else None diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index 1478454e..2c6ba0df 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -176,7 +176,10 @@ def test_aperture(): navigator.activate_apertures() _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) - assert outgoing_beam.num_particles == outgoing_p_array.rparticles.shape[1] + assert ( + int(outgoing_beam.num_particles_survived) + == outgoing_p_array.rparticles.shape[1] + ) def test_aperture_elliptical(): @@ -215,7 +218,10 @@ def test_aperture_elliptical(): navigator.activate_apertures() _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) - assert outgoing_beam.num_particles == outgoing_p_array.rparticles.shape[1] + assert ( + int(outgoing_beam.num_particles_survived) + == outgoing_p_array.rparticles.shape[1] + ) def test_solenoid(): diff --git a/tests/test_dipole.py b/tests/test_dipole.py index 08d016e4..56509fe0 100644 --- a/tests/test_dipole.py +++ b/tests/test_dipole.py @@ -117,8 +117,11 @@ def test_dipole_bmadx_tracking(dtype): Test that the results of tracking through a dipole 
with the `"bmadx"` tracking method match the results from Bmad-X. """ - incoming = torch.load("tests/resources/bmadx/incoming.pt", weights_only=False).to( - dtype + bmad_loaded = torch.load( + "tests/resources/bmadx/incoming.pt", weights_only=False + ).to(dtype) + incoming = ParticleBeam( + particles=bmad_loaded.particles, energy=bmad_loaded.energy, dtype=dtype ) angle = torch.tensor([20 * torch.pi / 180], dtype=dtype) diff --git a/tests/test_drift.py b/tests/test_drift.py index f0115216..cffb05b7 100644 --- a/tests/test_drift.py +++ b/tests/test_drift.py @@ -67,9 +67,12 @@ def test_drift_bmadx_tracking(dtype): Test that the results of tracking through a drift with the `"bmadx"` tracking method match the results from Bmad-X. """ - incoming_beam = torch.load( + bmad_loaded = torch.load( "tests/resources/bmadx/incoming.pt", weights_only=False ).to(dtype) + incoming_beam = cheetah.ParticleBeam( + particles=bmad_loaded.particles, energy=bmad_loaded.energy, dtype=dtype + ) drift = cheetah.Drift( length=torch.tensor([1.0]), tracking_method="bmadx", dtype=dtype ) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index dbdedac3..eab8097c 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -175,8 +175,11 @@ def test_quadrupole_bmadx_tracking(dtype): Test that the results of tracking through a quadrupole with the `"bmadx"` tracking method match the results from Bmad-X. """ - incoming = torch.load("tests/resources/bmadx/incoming.pt", weights_only=False).to( - dtype + bmad_loaded = torch.load( + "tests/resources/bmadx/incoming.pt", weights_only=False + ).to(dtype) + incoming = ParticleBeam( + particles=bmad_loaded.particles, energy=bmad_loaded.energy, dtype=dtype ) quadrupole = Quadrupole( length=torch.tensor(1.0), diff --git a/tests/test_transverse_deflecting_cavity.py b/tests/test_transverse_deflecting_cavity.py index 231c883d..7cae14b5 100644 --- a/tests/test_transverse_deflecting_cavity.py +++ b/tests/test_transverse_deflecting_cavity.py @@ -10,9 +10,12 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype): Test that the results of tracking through a TDC with the `"bmadx"` tracking method match the results from Bmad-X. """ - incoming_beam = torch.load( + bmad_loaded = torch.load( "tests/resources/bmadx/incoming.pt", weights_only=False ).to(dtype) + incoming_beam = cheetah.ParticleBeam( + particles=bmad_loaded.particles, energy=bmad_loaded.energy, dtype=dtype + ) tdc = cheetah.TransverseDeflectingCavity( length=torch.tensor([1.0]), voltage=torch.tensor([1e7]), diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 31e11b3e..ab029125 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -450,21 +450,27 @@ def test_broadcasting_solenoid_misalignment(): assert outgoing.particle_charges.shape == (100_000,) assert outgoing.energy.shape == (2,) + @pytest.mark.parametrize("shape", ["rectangular", "elliptical"]) def test_vectorized_aperture_broadcasting(shape): """ Test that apertures work in a vectorised setting and that broadcasting rules are applied correctly. 
""" + torch.manual_seed(0) incoming = cheetah.ParticleBeam.from_parameters( - num_particles=100_000, energy=torch.tensor([154e6, 14e9]) + num_particles=100_000, + # mu_x=torch.tensor(1e-4), + sigma_py=torch.tensor(1e-4), + sigma_px=torch.tensor(2e-4), + energy=torch.tensor([154e6, 14e9]), ) segment = cheetah.Segment( elements=[ cheetah.Drift(length=torch.tensor(0.5)), cheetah.Aperture( - x_max=torch.tensor([[1e-3], [2e-3], [3e-3]]), - y_max=torch.tensor(1e-3), + x_max=torch.tensor([[1e-5], [2e-4], [3e-4]]), + y_max=torch.tensor(2e-4), shape=shape, ), cheetah.Drift(length=torch.tensor(0.5)), @@ -475,4 +481,15 @@ def test_vectorized_aperture_broadcasting(shape): assert outgoing.particles.shape == (3, 2, 100_000, 7) assert outgoing.particle_charges.shape == (100_000,) - assert outgoing.energy.shape == (2,) \ No newline at end of file + assert outgoing.energy.shape == (2,) + + if shape == "elliptical": + assert torch.allclose( + outgoing.particle_survival.sum(dim=-1)[:, 0], + torch.tensor([7672, 94523, 99547], dtype=outgoing.particle_survival.dtype), + ) + elif shape == "rectangular": + assert torch.allclose( + outgoing.particle_survival.sum(dim=-1)[:, 0], + torch.tensor([7935, 95400, 99719], dtype=outgoing.particle_survival.dtype), + ) From 247b64c0019b6aa7d501b69fe7160c87869bcac9 Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Thu, 24 Oct 2024 17:46:39 +0200 Subject: [PATCH 06/29] Include particle survival into the statistical beam parameters calculations --- cheetah/particles/particle_beam.py | 118 ++++++++++++++++++++++++----- cheetah/utils/__init__.py | 5 ++ cheetah/utils/statistics.py | 64 ++++++++++++++++ tests/test_compare_ocelot.py | 9 +++ 4 files changed, 177 insertions(+), 19 deletions(-) create mode 100644 cheetah/utils/statistics.py diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 51e32f62..384be6c5 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -6,7 +6,11 @@ from torch.distributions import MultivariateNormal from cheetah.particles.beam import Beam -from cheetah.utils import elementwise_linspace +from cheetah.utils import ( + elementwise_linspace, + unbiased_weighted_covariance, + unbiased_weighted_std, +) speed_of_light = torch.tensor(constants.speed_of_light) # In m/s electron_mass = torch.tensor(constants.electron_mass) # In kg @@ -790,14 +794,24 @@ def __len__(self) -> int: @property def total_charge(self) -> torch.Tensor: + """Returns the total charge of the beam in C. + Note: it does not take into account the survival of the particles. 
+ """ return torch.sum(self.particle_charges, dim=-1) + @property + def total_charge_survived(self) -> torch.Tensor: + """Returns the total charge of the survived macroparticles, in C.""" + return torch.sum(self.particle_charges * self.particle_survival, dim=-1) + @property def num_particles(self) -> int: + """Length of the macroparticle array, does not account for lost.""" return self.particles.shape[-2] @property def num_particles_survived(self) -> torch.Tensor: + """Returns the number of macroparticles that survived the simulation.""" return self.particle_survival.sum(dim=-1) @property @@ -810,11 +824,24 @@ def x(self, value: torch.Tensor) -> None: @property def mu_x(self) -> Optional[torch.Tensor]: - return self.x.mean(dim=-1) if self is not Beam.empty else None + """Mean of the :math:`x` coordinates of the particles, weighted by their + survival probability.""" + return ( + torch.sum((self.x * self.particle_survival), dim=-1) + / self.num_particles_survived + if self is not Beam.empty + else None + ) @property def sigma_x(self) -> Optional[torch.Tensor]: - return self.x.std(dim=-1) if self is not Beam.empty else None + """Standard deviation of the :math:`x` coordinates of the particles, weighted + by their survival probability.""" + return ( + unbiased_weighted_std(self.x, self.particle_survival, dim=-1) + if self is not Beam.empty + else None + ) @property def px(self) -> Optional[torch.Tensor]: @@ -826,11 +853,24 @@ def px(self, value: torch.Tensor) -> None: @property def mu_px(self) -> Optional[torch.Tensor]: - return self.px.mean(dim=-1) if self is not Beam.empty else None + """Mean of the :math:`px` coordinates of the particles, weighted by their + survival probability.""" + return ( + torch.sum((self.px * self.particle_survival), dim=-1) + / self.num_particles_survived + if self is not Beam.empty + else None + ) @property def sigma_px(self) -> Optional[torch.Tensor]: - return self.px.std(dim=-1) if self is not Beam.empty else None + """Standard deviation of the :math:`px` coordinates of the particles, weighted + by their survival probability.""" + return ( + unbiased_weighted_std(self.px, self.particle_survival, dim=-1) + if self is not Beam.empty + else None + ) @property def y(self) -> Optional[torch.Tensor]: @@ -842,11 +882,20 @@ def y(self, value: torch.Tensor) -> None: @property def mu_y(self) -> Optional[float]: - return self.y.mean(dim=-1) if self is not Beam.empty else None + return ( + torch.sum((self.y * self.particle_survival), dim=-1) + / self.num_particles_survived + if self is not Beam.empty + else None + ) @property def sigma_y(self) -> Optional[torch.Tensor]: - return self.y.std(dim=-1) if self is not Beam.empty else None + return ( + unbiased_weighted_std(self.y, self.particle_survival, dim=-1) + if self is not Beam.empty + else None + ) @property def py(self) -> Optional[torch.Tensor]: @@ -858,11 +907,20 @@ def py(self, value: torch.Tensor) -> None: @property def mu_py(self) -> Optional[torch.Tensor]: - return self.py.mean(dim=-1) if self is not Beam.empty else None + return ( + torch.sum((self.py * self.particle_survival), dim=-1) + / self.num_particles_survived + if self is not Beam.empty + else None + ) @property def sigma_py(self) -> Optional[torch.Tensor]: - return self.py.std(dim=-1) if self is not Beam.empty else None + return ( + unbiased_weighted_std(self.py, self.particle_survival, dim=-1) + if self is not Beam.empty + else None + ) @property def tau(self) -> Optional[torch.Tensor]: @@ -874,11 +932,20 @@ def tau(self, value: torch.Tensor) -> None: 
@property def mu_tau(self) -> Optional[torch.Tensor]: - return self.tau.mean(dim=-1) if self is not Beam.empty else None + return ( + torch.sum((self.tau * self.particle_survival), dim=-1) + / self.num_particles_survived + if self is not Beam.empty + else None + ) @property def sigma_tau(self) -> Optional[torch.Tensor]: - return self.tau.std(dim=-1) if self is not Beam.empty else None + return ( + unbiased_weighted_std(self.tau, self.particle_survival, dim=-1) + if self is not Beam.empty + else None + ) @property def p(self) -> Optional[torch.Tensor]: @@ -890,24 +957,37 @@ def p(self, value: torch.Tensor) -> None: @property def mu_p(self) -> Optional[torch.Tensor]: - return self.p.mean(dim=-1) if self is not Beam.empty else None + return ( + torch.sum((self.p * self.particle_survival), dim=-1) + / self.num_particles_survived + if self is not Beam.empty + else None + ) @property def sigma_p(self) -> Optional[torch.Tensor]: - return self.p.std(dim=-1) if self is not Beam.empty else None + return ( + unbiased_weighted_std(self.p, self.particle_survival, dim=-1) + if self is not Beam.empty + else None + ) @property def sigma_xpx(self) -> torch.Tensor: - return torch.mean( - (self.x - self.mu_x.unsqueeze(-1)) * (self.px - self.mu_px.unsqueeze(-1)), - dim=-1, + r"""Returns the covariance between x and px. :math:`\sigma_{x, px}^2`. + It is weighted by the survival probability of the particles. + """ + return unbiased_weighted_covariance( + self.x, self.px, self.particle_survival, dim=-1 ) @property def sigma_ypy(self) -> torch.Tensor: - return torch.mean( - (self.y - self.mu_y.unsqueeze(-1)) * (self.py - self.mu_py.unsqueeze(-1)), - dim=-1, + r"""Returns the covariance between y and py. :math:`\sigma_{y, py}^2`. + It is weighted by the survival probability of the particles. + """ + return unbiased_weighted_covariance( + self.y, self.py, self.particle_survival, dim=-1 ) @property diff --git a/cheetah/utils/__init__.py b/cheetah/utils/__init__.py index a29d9ae9..9073c158 100644 --- a/cheetah/utils/__init__.py +++ b/cheetah/utils/__init__.py @@ -3,4 +3,9 @@ from .elementwise_linspace import elementwise_linspace # noqa: F401 from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401 from .physics import compute_relativistic_factors # noqa: F401 +from .statistics import ( # noqa: F401 + unbiased_weighted_covariance, + unbiased_weighted_std, + unbiased_weighted_variance, +) from .unique_name_generator import UniqueNameGenerator # noqa: F401 diff --git a/cheetah/utils/statistics.py b/cheetah/utils/statistics.py new file mode 100644 index 00000000..8e5a3594 --- /dev/null +++ b/cheetah/utils/statistics.py @@ -0,0 +1,64 @@ +import torch + + +def unbiased_weighted_covariance( + input1: torch.Tensor, input2: torch.Tensor, weights: torch.Tensor, dim: int = None +) -> torch.Tensor: + """Compute the unbiased weighted covariance of two tensors. + inputs and weights should be broadcastable. + + :param input1: Input tensor 1. (batch_size, sample_size) + :param input2: Input tensor 2. (batch_size, sample_size) + :param weights: Weights tensor. (batch_size, sample_size) + :param dim: Dimension along which to compute the covariance. + :return: Unbiased weighted covariance. 
(batch_size, 2, 2) + """ + weighted_mean1 = torch.sum(input1 * weights, dim=dim) / torch.sum(weights, dim=dim) + weighted_mean2 = torch.sum(input2 * weights, dim=dim) / torch.sum(weights, dim=dim) + correction_factor = torch.sum(weights, dim=dim) - torch.sum( + weights**2, dim=dim + ) / torch.sum(weights, dim=dim) + covariance = torch.sum( + weights + * (input1 - weighted_mean1.unsqueeze(-1)) + * (input2 - weighted_mean2.unsqueeze(-1)), + dim=dim, + ) / (correction_factor) + return covariance + + +def unbiased_weighted_variance( + input: torch.Tensor, weights: torch.Tensor, dim: int = None +) -> torch.Tensor: + """ + Compute the unbiased weighted variance of a tensor. + inputs and weights should be broadcastable. + + :param input: Input tensor. + :param weights: Weights tensor. + :param dim: Dimension along which to compute the variance. + :return: Unbiased weighted variance. + """ + weighted_mean = torch.sum(input * weights, dim=dim) / torch.sum(weights, dim=dim) + correction_factor = torch.sum(weights, dim=dim) - torch.sum( + weights**2, dim=dim + ) / torch.sum(weights, dim=dim) + variance = torch.sum( + weights * (input - weighted_mean.unsqueeze(-1)) ** 2, dim=dim + ) / (correction_factor) + return variance + + +def unbiased_weighted_std( + input: torch.Tensor, weights: torch.Tensor, dim: int = None +) -> torch.Tensor: + """ + Compute the unbiased weighted standard deviation of a tensor. + inputs and weights should be broadcastable. + + :param input: Input tensor. + :param weights: Weights tensor. + :param dim: Dimension along which to compute the standard deviation. + :return: Unbiased weighted standard deviation. + """ + return torch.sqrt(unbiased_weighted_variance(input, weights, dim=dim)) diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index 2c6ba0df..beb3bfa4 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -223,6 +223,15 @@ def test_aperture_elliptical(): == outgoing_p_array.rparticles.shape[1] ) + assert np.allclose( + outgoing_beam.mu_x.cpu().numpy(), + outgoing_p_array.x().mean(), + ) + assert np.allclose( + outgoing_beam.mu_px.cpu().numpy(), + outgoing_p_array.px().mean(), + ) + def test_solenoid(): """ From 5a79ed35ac09293153ad95ebab117431147f3f26 Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Fri, 25 Oct 2024 15:04:19 +0200 Subject: [PATCH 07/29] Add tests for variance calculation --- tests/test_statistics_calculation.py | 61 ++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tests/test_statistics_calculation.py diff --git a/tests/test_statistics_calculation.py b/tests/test_statistics_calculation.py new file mode 100644 index 00000000..c7268a59 --- /dev/null +++ b/tests/test_statistics_calculation.py @@ -0,0 +1,61 @@ +import torch + +from cheetah.utils import unbiased_weighted_covariance, unbiased_weighted_variance + + +def test_unbiased_weighted_variance_with_same_weights(): + """Test that the variance is calculated correctly with equal weights.""" + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) + expected_variance = torch.var(data, unbiased=True) + calculated_variance = unbiased_weighted_variance(data, weights) + assert torch.allclose(calculated_variance, expected_variance) + + +def test_unbiased_weighted_variance_with_single_element(): + """Test that the variance is nan when there is only one element.""" + data = torch.tensor([42.0]) + weights = torch.tensor([1.0]) + assert torch.isnan(unbiased_weighted_variance(data, weights)) + + 
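# Why the single-element case above yields NaN: `unbiased_weighted_variance` divides by
# the reliability-weights correction factor sum(w) - sum(w**2) / sum(w), which is
# 1 - 1/1 = 0 for a single unit weight, so the variance evaluates to 0/0 = NaN. A
# minimal sketch of that factor (it assumes only the formula quoted from
# cheetah/utils/statistics.py above; the helper name is illustrative, not Cheetah API):
def _correction_factor_sketch(weights: torch.Tensor) -> torch.Tensor:
    return torch.sum(weights) - torch.sum(weights**2) / torch.sum(weights)


assert _correction_factor_sketch(torch.tensor([1.0])) == 0.0  # leads to 0/0 = nan
assert _correction_factor_sketch(torch.ones(5)) == 4.0  # equal weights recover n - 1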
+def test_unbiased_weighted_variance_with_different_weights(): + """Test that the variance is calculated correctly with different weights.""" + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.5, 0.5, 0.5, 0, 0]) + expected_variance = torch.var(torch.tensor([1.0, 2.0, 3.0]), unbiased=True) + calculated_variance = unbiased_weighted_variance(data, weights) + assert torch.allclose(calculated_variance, expected_variance) + + +def test_unbiased_weighted_variance_with_zero_weights(): + """Test that the variance is nan when all weights are zero.""" + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]) + assert torch.isnan(unbiased_weighted_variance(data, weights)) + + +def test_unbiased_weighted_variance_with_small_numbers(): + """Test that the variance is calculated correctly with small numbers.""" + data = torch.tensor([1e-10, 2e-10, 3e-10, 4e-10, 5e-10], dtype=torch.float32) + weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32) + expected_variance = torch.var(data, unbiased=True) + calculated_variance = unbiased_weighted_variance(data, weights) + assert torch.allclose(calculated_variance, expected_variance) + + +def test_unbiased_weighted_covariance_reduced_to_variance(): + """Test that the covariance calculation is reduced to the variance when both inputs + are the same. + """ + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + equal_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) + expected_variance = torch.var(data, unbiased=True) + calculated_covariance = unbiased_weighted_covariance(data, data, equal_weights) + assert torch.allclose(calculated_covariance, expected_variance) + + different_weights = torch.tensor([0.5, 1.0, 1.0, 0.9, 0.9]) + assert torch.allclose( + unbiased_weighted_covariance(data, data, different_weights), + unbiased_weighted_variance(data, different_weights), + ) From 0d7d07d2a083a900f0b37041417840c6ebfe28f0 Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Fri, 25 Oct 2024 15:33:41 +0200 Subject: [PATCH 08/29] Account for the lost particles in space charge effects --- cheetah/accelerator/space_charge_kick.py | 3 +- tests/test_space_charge_kick.py | 39 ++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 4e23cff3..5fa6e226 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -151,7 +151,8 @@ def _deposit_charge_on_grid( ) # Accumulate the charge contributions - repeated_charges = beam.particle_charges.repeat_interleave( + survived_particle_charges = beam.particle_charges * beam.particle_survival + repeated_charges = survived_particle_charges.repeat_interleave( repeats=8, dim=-1 ) # Shape:(..., 8 * num_particles) values = (cell_weights.flatten(start_dim=-2) * repeated_charges)[valid_mask] diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 8d91cb32..28bb8234 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -274,3 +274,42 @@ def test_space_charge_with_ares_astra_beam(): beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") _ = segment.track(beam) + + +def test_space_charge_with_aperture_cutoff(): + """ + Tests that the space charge kick is correctly calculated only for the surviving + particles when an aperture is used. 
+ """ + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + total_charge=torch.tensor(1e-9), + mu_x=torch.tensor(5e-5), + sigma_px=torch.tensor(1e-4), + sigma_py=torch.tensor(1e-4), + ) + + drift1 = cheetah.Drift(length=torch.tensor(0.2)) + aperture = cheetah.Aperture( + x_max=torch.tensor(1e-4), + y_max=torch.tensor(1e-4), + shape="rectangular", + is_active="False", + ) + drift2 = cheetah.Drift(length=torch.tensor(0.25)) + space_charge = cheetah.SpaceChargeKick(effect_length=torch.tensor(0.5)) + segment = cheetah.Segment(elements=[drift1, aperture, drift2, space_charge, drift2]) + + outgoing_beam_without_aperture = segment.track(incoming_beam) + + # Activate the aperture + aperture.is_active = True + outgoing_beam_with_aperture = segment.track(incoming_beam) + + # Check that with particle loss the space charge kick is different + assert not torch.allclose( + outgoing_beam_with_aperture.particles, outgoing_beam_without_aperture.particles + ) + + # Check that the number of surviving particles is less than the initial number + assert outgoing_beam_with_aperture.particle_survival.sum(dim=-1).max() < 10_000 From c8a1a023d9b2c30f4fa30084ea9727505d61a06d Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Fri, 25 Oct 2024 15:37:18 +0200 Subject: [PATCH 09/29] Update `CHANGELOG.md` --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65c9c24d..e18041f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265) (@jank324, @cr-xu, @hespe, @roussel-ryan) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu) - `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) +- Rework the `Aperture` element. Now `ParticleBeam` has a `particle_survival` attribute that keeps track of the lost particles. The statistical beam parameters are calculated only w.r.t. surviving particles. Note that the `Aperture` breaks differentiability if activated. (see #268) (@cr-xu) ### 🚀 Features From 0cd8edac756b6afae2f6bf30baa16c0b693fc509 Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Fri, 25 Oct 2024 15:45:54 +0200 Subject: [PATCH 10/29] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e18041f6..062bace3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. 
For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265) (@jank324, @cr-xu, @hespe, @roussel-ryan) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu) - `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) -- Rework the `Aperture` element. Now `ParticleBeam` has a `particle_survival` attribute that keeps track of the lost particles. The statistical beam parameters are calculated only w.r.t. surviving particles. Note that the `Aperture` breaks differentiability if activated. (see #268) (@cr-xu) +- Rework the `Aperture` element. Now `ParticleBeam` has a `particle_survival` attribute that keeps track of the lost particles. The statistical beam parameters are calculated only w.r.t. surviving particles. Note that the `Aperture` breaks differentiability if activated. (see #268) (@cr-xu, @jank324) ### 🚀 Features From 157f6dd917aae4e96fe3ac485d3022d47758ecac Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 16:39:13 +0100 Subject: [PATCH 11/29] More descriptive property name --- CHANGELOG.md | 2 +- cheetah/accelerator/aperture.py | 4 +- cheetah/accelerator/cavity.py | 2 +- cheetah/accelerator/dipole.py | 2 +- cheetah/accelerator/drift.py | 2 +- cheetah/accelerator/element.py | 2 +- cheetah/accelerator/quadrupole.py | 2 +- cheetah/accelerator/space_charge_kick.py | 12 +++-- .../transverse_deflecting_cavity.py | 2 +- cheetah/particles/particle_beam.py | 48 +++++++++---------- tests/test_space_charge_kick.py | 2 +- tests/test_vectorized.py | 12 +++-- 12 files changed, 49 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d52d220..77afef54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe) - `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) - The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324) -- Rework the `Aperture` element. Now `ParticleBeam` has a `particle_survival` attribute that keeps track of the lost particles. The statistical beam parameters are calculated only w.r.t. surviving particles. Note that the `Aperture` breaks differentiability if activated. (see #268) (@cr-xu, @jank324) +- Rework the `Aperture` element. Now `ParticleBeam` has a `survived_probabilities` attribute that keeps track of the lost particles. The statistical beam parameters are calculated only w.r.t. surviving particles.
Note that the `Aperture` breaks differentiability if activated. (see #268) (@cr-xu, @jank324) ### 🚀 Features diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index a9b3621f..a90fdbac 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -99,7 +99,7 @@ def track(self, incoming: Beam) -> Beam: elif self.shape == "elliptical": survived_mask = (incoming.x**2 / x_max**2 + incoming.y**2 / y_max**2) <= 1.0 - outgoing_survival = incoming.particle_survival * survived_mask + outgoing_survival = incoming.survived_probabilities * survived_mask # outgoing_particles = incoming.particles[survived_mask] @@ -117,7 +117,7 @@ def track(self, incoming: Beam) -> Beam: particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, - particle_survival=outgoing_survival, + survived_probabilities=outgoing_survival, ) def split(self, resolution: torch.Tensor) -> list[Element]: diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index bdd32aa7..02cc4c23 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -249,7 +249,7 @@ def _track_beam(self, incoming: Beam) -> Beam: particle_charges=incoming.particle_charges, device=outgoing_particles.device, dtype=outgoing_particles.dtype, - particle_survival=incoming.particle_survival, + survived_probabilities=incoming.survived_probabilities, ) return outgoing diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 43922cd6..cb7e1133 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -262,7 +262,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, - particle_survival=incoming.particle_survival, + survived_probabilities=incoming.survived_probabilities, ) return outgoing_beam diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 5773eef3..1fd72aae 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -116,7 +116,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, - particle_survival=incoming.particle_survival, + survived_probabilities=incoming.survived_probabilities, ) return outgoing_beam diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 76b99c3f..dbbf9a3a 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -83,7 +83,7 @@ def track(self, incoming: Beam) -> Beam: particle_charges=incoming.particle_charges, device=new_particles.device, dtype=new_particles.dtype, - particle_survival=incoming.particle_survival, + survived_probabilities=incoming.survived_probabilities, ) else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 30da0f5b..dc546c65 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -192,7 +192,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, - particle_survival=incoming.particle_survival, + survived_probabilities=incoming.survived_probabilities, ) return outgoing_beam diff --git 
a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 4cbb7e88..34437530 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -149,7 +149,7 @@ def _deposit_charge_on_grid( ) # Accumulate the charge contributions - survived_particle_charges = beam.particle_charges * beam.particle_survival + survived_particle_charges = beam.particle_charges * beam.survived_probabilities repeated_charges = survived_particle_charges.repeat_interleave( repeats=8, dim=-1 ) # Shape:(..., 8 * num_particles) @@ -562,7 +562,7 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=incoming.particle_charges.unsqueeze(0), device=incoming.particles.device, dtype=incoming.particles.dtype, - particle_survival=incoming.particle_survival.unsqueeze(0), + survived_probabilities=incoming.survived_probabilities.unsqueeze(0), ) else: is_incoming_vectorized = True @@ -577,7 +577,7 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ), device=vectorized_incoming.particles.device, dtype=vectorized_incoming.particles.dtype, - particle_survival=vectorized_incoming.particle_survival.flatten( + survived_probabilities=vectorized_incoming.survived_probabilities.flatten( end_dim=-2 ), ) @@ -624,7 +624,9 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=vectorized_incoming.particle_charges.squeeze(0), device=vectorized_incoming.particles.device, dtype=vectorized_incoming.particles.dtype, - particle_survival=vectorized_incoming.particle_survival.squeeze(0), + survived_probabilities=vectorized_incoming.survived_probabilities.squeeze( + 0 + ), ) else: # Reverse the flattening of the vector dimensions @@ -634,7 +636,7 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ), vectorized_incoming.energy, particle_charges=vectorized_incoming.particle_charges, - particle_survival=vectorized_incoming.particle_survival, + survived_probabilities=vectorized_incoming.survived_probabilities, device=vectorized_incoming.particles.device, dtype=vectorized_incoming.particles.dtype, ) diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index ceedf51b..0d621d93 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -210,7 +210,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), ref_energy, particle_charges=incoming.particle_charges, - particle_survival=incoming.particle_survival, + survived_probabilities=incoming.survived_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 2a93b4ed..227ed410 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -27,7 +27,7 @@ class ParticleBeam(Beam): :param particles: List of 7-dimensional particle vectors. :param energy: Reference energy of the beam in eV. :param particle_charges: Charges of macroparticles in the beam in C. - :param particle_survival: Survival probability of each particle in the beam. + :param survived_probabilities: Survival probability of each particle in the beam. Default to array with ones. (1: survive, 0: lost) :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. 
@@ -39,7 +39,7 @@ def __init__( particles: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, - particle_survival: Optional[torch.Tensor] = None, + survived_probabilities: Optional[torch.Tensor] = None, device=None, dtype=None, ) -> None: @@ -63,14 +63,14 @@ def __init__( ), ) self.register_buffer("energy", energy.to(**factory_kwargs)) - if particle_survival is not None: + if survived_probabilities is not None: # Try to broadcast the survival probability to the particles shape - particle_survival = particle_survival.expand(particles.shape[:-1]) + survived_probabilities = survived_probabilities.expand(particles.shape[:-1]) else: # If no survival probability provided, default to all particles surviving - particle_survival = torch.ones(particles.shape[:-1], **factory_kwargs) + survived_probabilities = torch.ones(particles.shape[:-1], **factory_kwargs) self.register_buffer( - "particle_survival", particle_survival.to(**factory_kwargs) + "survived_probabilities", survived_probabilities.to(**factory_kwargs) ) @classmethod @@ -812,7 +812,7 @@ def from_xyz_pxpypz( xp_coordinates: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, - particle_survival: Optional[torch.Tensor] = None, + survived_probabilities: Optional[torch.Tensor] = None, device=None, dtype=torch.float32, ) -> torch.Tensor: @@ -825,7 +825,7 @@ def from_xyz_pxpypz( particles=xp_coordinates.clone(), energy=energy, particle_charges=particle_charges, - particle_survival=particle_survival, + survived_probabilities=survived_probabilities, device=device, dtype=dtype, ) @@ -899,7 +899,7 @@ def total_charge(self) -> torch.Tensor: @property def total_charge_survived(self) -> torch.Tensor: """Returns the total charge of the survived macroparticles, in C.""" - return torch.sum(self.particle_charges * self.particle_survival, dim=-1) + return torch.sum(self.particle_charges * self.survived_probabilities, dim=-1) @property def num_particles(self) -> int: @@ -909,7 +909,7 @@ def num_particles(self) -> int: @property def num_particles_survived(self) -> torch.Tensor: """Returns the number of macroparticles that survived the simulation.""" - return self.particle_survival.sum(dim=-1) + return self.survived_probabilities.sum(dim=-1) @property def x(self) -> Optional[torch.Tensor]: @@ -924,7 +924,7 @@ def mu_x(self) -> Optional[torch.Tensor]: """Mean of the :math:`x` coordinates of the particles, weighted by their survival probability.""" return ( - torch.sum((self.x * self.particle_survival), dim=-1) + torch.sum((self.x * self.survived_probabilities), dim=-1) / self.num_particles_survived if self is not Beam.empty else None @@ -935,7 +935,7 @@ def sigma_x(self) -> Optional[torch.Tensor]: """Standard deviation of the :math:`x` coordinates of the particles, weighted by their survival probability.""" return ( - unbiased_weighted_std(self.x, self.particle_survival, dim=-1) + unbiased_weighted_std(self.x, self.survived_probabilities, dim=-1) if self is not Beam.empty else None ) @@ -953,7 +953,7 @@ def mu_px(self) -> Optional[torch.Tensor]: """Mean of the :math:`px` coordinates of the particles, weighted by their survival probability.""" return ( - torch.sum((self.px * self.particle_survival), dim=-1) + torch.sum((self.px * self.survived_probabilities), dim=-1) / self.num_particles_survived if self is not Beam.empty else None @@ -964,7 +964,7 @@ def sigma_px(self) -> Optional[torch.Tensor]: """Standard deviation of the :math:`px` coordinates of the particles, weighted by their survival 
probability.""" return ( - unbiased_weighted_std(self.px, self.particle_survival, dim=-1) + unbiased_weighted_std(self.px, self.survived_probabilities, dim=-1) if self is not Beam.empty else None ) @@ -980,7 +980,7 @@ def y(self, value: torch.Tensor) -> None: @property def mu_y(self) -> Optional[float]: return ( - torch.sum((self.y * self.particle_survival), dim=-1) + torch.sum((self.y * self.survived_probabilities), dim=-1) / self.num_particles_survived if self is not Beam.empty else None @@ -989,7 +989,7 @@ def mu_y(self) -> Optional[float]: @property def sigma_y(self) -> Optional[torch.Tensor]: return ( - unbiased_weighted_std(self.y, self.particle_survival, dim=-1) + unbiased_weighted_std(self.y, self.survived_probabilities, dim=-1) if self is not Beam.empty else None ) @@ -1005,7 +1005,7 @@ def py(self, value: torch.Tensor) -> None: @property def mu_py(self) -> Optional[torch.Tensor]: return ( - torch.sum((self.py * self.particle_survival), dim=-1) + torch.sum((self.py * self.survived_probabilities), dim=-1) / self.num_particles_survived if self is not Beam.empty else None @@ -1014,7 +1014,7 @@ def mu_py(self) -> Optional[torch.Tensor]: @property def sigma_py(self) -> Optional[torch.Tensor]: return ( - unbiased_weighted_std(self.py, self.particle_survival, dim=-1) + unbiased_weighted_std(self.py, self.survived_probabilities, dim=-1) if self is not Beam.empty else None ) @@ -1030,7 +1030,7 @@ def tau(self, value: torch.Tensor) -> None: @property def mu_tau(self) -> Optional[torch.Tensor]: return ( - torch.sum((self.tau * self.particle_survival), dim=-1) + torch.sum((self.tau * self.survived_probabilities), dim=-1) / self.num_particles_survived if self is not Beam.empty else None @@ -1039,7 +1039,7 @@ def mu_tau(self) -> Optional[torch.Tensor]: @property def sigma_tau(self) -> Optional[torch.Tensor]: return ( - unbiased_weighted_std(self.tau, self.particle_survival, dim=-1) + unbiased_weighted_std(self.tau, self.survived_probabilities, dim=-1) if self is not Beam.empty else None ) @@ -1055,7 +1055,7 @@ def p(self, value: torch.Tensor) -> None: @property def mu_p(self) -> Optional[torch.Tensor]: return ( - torch.sum((self.p * self.particle_survival), dim=-1) + torch.sum((self.p * self.survived_probabilities), dim=-1) / self.num_particles_survived if self is not Beam.empty else None @@ -1064,7 +1064,7 @@ def mu_p(self) -> Optional[torch.Tensor]: @property def sigma_p(self) -> Optional[torch.Tensor]: return ( - unbiased_weighted_std(self.p, self.particle_survival, dim=-1) + unbiased_weighted_std(self.p, self.survived_probabilities, dim=-1) if self is not Beam.empty else None ) @@ -1075,7 +1075,7 @@ def sigma_xpx(self) -> torch.Tensor: It is weighted by the survival probability of the particles. """ return unbiased_weighted_covariance( - self.x, self.px, self.particle_survival, dim=-1 + self.x, self.px, self.survived_probabilities, dim=-1 ) @property @@ -1084,7 +1084,7 @@ def sigma_ypy(self) -> torch.Tensor: It is weighted by the survival probability of the particles. 
""" return unbiased_weighted_covariance( - self.y, self.py, self.particle_survival, dim=-1 + self.y, self.py, self.survived_probabilities, dim=-1 ) @property diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 36f65edf..6faee89f 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -315,4 +315,4 @@ def test_space_charge_with_aperture_cutoff(): ) # Check that the number of surviving particles is less than the initial number - assert outgoing_beam_with_aperture.particle_survival.sum(dim=-1).max() < 10_000 + assert outgoing_beam_with_aperture.survived_probabilities.sum(dim=-1).max() < 10_000 diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 23f2f277..fa82f908 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -485,11 +485,15 @@ def test_vectorized_aperture_broadcasting(shape): if shape == "elliptical": assert torch.allclose( - outgoing.particle_survival.sum(dim=-1)[:, 0], - torch.tensor([7672, 94523, 99547], dtype=outgoing.particle_survival.dtype), + outgoing.survived_probabilities.sum(dim=-1)[:, 0], + torch.tensor( + [7672, 94523, 99547], dtype=outgoing.survived_probabilities.dtype + ), ) elif shape == "rectangular": assert torch.allclose( - outgoing.particle_survival.sum(dim=-1)[:, 0], - torch.tensor([7935, 95400, 99719], dtype=outgoing.particle_survival.dtype), + outgoing.survived_probabilities.sum(dim=-1)[:, 0], + torch.tensor( + [7935, 95400, 99719], dtype=outgoing.survived_probabilities.dtype + ), ) From 0172baa77a3d7a1e870458f5c53ac7cef817c8c7 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 18:31:41 +0100 Subject: [PATCH 12/29] Align the way survival probabilities are instatiated with other properties in Cheetah --- cheetah/accelerator/space_charge_kick.py | 87 ++++++++++++------------ cheetah/particles/particle_beam.py | 15 ++-- 2 files changed, 52 insertions(+), 50 deletions(-) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 34437530..84c29066 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -552,22 +552,30 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: # This flattening is a hack to only think about one vector dimension in the # following code. It is reversed at the end of the function. - # Make sure that the incoming beam has at least one vector dimension - if len(incoming.particles.shape) == 2: - is_incoming_vectorized = False - - vectorized_incoming = ParticleBeam( - particles=incoming.particles.unsqueeze(0), - energy=incoming.energy.unsqueeze(0), - particle_charges=incoming.particle_charges.unsqueeze(0), - device=incoming.particles.device, - dtype=incoming.particles.dtype, - survived_probabilities=incoming.survived_probabilities.unsqueeze(0), - ) - else: - is_incoming_vectorized = True - - vectorized_incoming = incoming + # Make sure that the incoming beam has at least one vector dimension by + # broadcasting with a dummy dimension (1,). 
+ vector_shape = torch.broadcast_shapes( + incoming.particles.shape[:-2], + incoming.energy.shape, + incoming.particle_charges.shape[:-1], + incoming.survived_probabilities.shape[:-1], + (1,), + ) + vectorized_incoming = ParticleBeam( + particles=torch.broadcast_to( + incoming.particles, (*vector_shape, incoming.num_particles, 7) + ), + energy=torch.broadcast_to(incoming.energy, vector_shape), + particle_charges=torch.broadcast_to( + incoming.particle_charges, (*vector_shape, incoming.num_particles) + ), + survived_probabilities=torch.broadcast_to( + incoming.survived_probabilities, + (*vector_shape, incoming.num_particles), + ), + device=incoming.particles.device, + dtype=incoming.particles.dtype, + ) flattened_incoming = ParticleBeam( particles=vectorized_incoming.particles.flatten(end_dim=-3), @@ -575,11 +583,11 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=vectorized_incoming.particle_charges.flatten( end_dim=-2 ), - device=vectorized_incoming.particles.device, - dtype=vectorized_incoming.particles.dtype, survived_probabilities=vectorized_incoming.survived_probabilities.flatten( end_dim=-2 ), + device=vectorized_incoming.particles.device, + dtype=vectorized_incoming.particles.dtype, ) flattened_length_effect = self.effect_length.flatten(end_dim=-1) @@ -616,30 +624,25 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ..., 2 ] * dt.unsqueeze(-1) - if not is_incoming_vectorized: - # Reshape to the original non-vectorised shape - outgoing = ParticleBeam.from_xyz_pxpypz( - xp_coordinates.squeeze(0), - vectorized_incoming.energy.squeeze(0), - particle_charges=vectorized_incoming.particle_charges.squeeze(0), - device=vectorized_incoming.particles.device, - dtype=vectorized_incoming.particles.dtype, - survived_probabilities=vectorized_incoming.survived_probabilities.squeeze( - 0 - ), - ) - else: - # Reverse the flattening of the vector dimensions - outgoing = ParticleBeam.from_xyz_pxpypz( - xp_coordinates.unflatten( - dim=0, sizes=vectorized_incoming.particles.shape[:-2] - ), - vectorized_incoming.energy, - particle_charges=vectorized_incoming.particle_charges, - survived_probabilities=vectorized_incoming.survived_probabilities, - device=vectorized_incoming.particles.device, - dtype=vectorized_incoming.particles.dtype, - ) + # Reverse the flattening of the vector dimensions + outgoing_vector_shape = torch.broadcast_shapes( + incoming.particles.shape[:-2], + incoming.energy.shape, + incoming.particle_charges.shape[:-1], + incoming.survived_probabilities.shape[:-1], + self.effect_length.shape, + ) + outgoing = ParticleBeam.from_xyz_pxpypz( + xp_coordinates=xp_coordinates.reshape( + (*outgoing_vector_shape, incoming.num_particles, 7) + ), + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survived_probabilities=incoming.survived_probabilities, + device=incoming.particles.device, + dtype=incoming.particles.dtype, + ) + return outgoing else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 227ed410..1062959f 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -59,18 +59,17 @@ def __init__( ( particle_charges.to(**factory_kwargs) if particle_charges is not None - else torch.zeros(particles.shape[:2], **factory_kwargs) + else torch.zeros(particles.shape[-2], **factory_kwargs) ), ) self.register_buffer("energy", energy.to(**factory_kwargs)) - if survived_probabilities is not None: - # 
Try to broadcast the survival probability to the particles shape - survived_probabilities = survived_probabilities.expand(particles.shape[:-1]) - else: - # If no survival probability provided, default to all particles surviving - survived_probabilities = torch.ones(particles.shape[:-1], **factory_kwargs) self.register_buffer( - "survived_probabilities", survived_probabilities.to(**factory_kwargs) + "survived_probabilities", + ( + survived_probabilities.to(**factory_kwargs) + if survived_probabilities is not None + else torch.ones(particles.shape[-2], **factory_kwargs) + ), ) @classmethod From e729a066683ecdf7f02f36e50eac5fb2140595cc Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 18:37:04 +0100 Subject: [PATCH 13/29] More consistent argument order --- cheetah/accelerator/aperture.py | 6 +++--- cheetah/accelerator/cavity.py | 6 +++--- cheetah/accelerator/dipole.py | 2 +- cheetah/accelerator/drift.py | 2 +- cheetah/accelerator/element.py | 2 +- cheetah/accelerator/quadrupole.py | 8 +++++--- cheetah/accelerator/transverse_deflecting_cavity.py | 6 ++++-- 7 files changed, 18 insertions(+), 14 deletions(-) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index a90fdbac..0d5005f0 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -112,12 +112,12 @@ def track(self, incoming: Beam) -> Beam: # ] return ParticleBeam( - outgoing_particles, - incoming.energy, + particles=outgoing_particles, + energy=incoming.energy, particle_charges=incoming.particle_charges, + survived_probabilities=outgoing_survival, device=incoming.particles.device, dtype=incoming.particles.dtype, - survived_probabilities=outgoing_survival, ) def split(self, resolution: torch.Tensor) -> list[Element]: diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 02cc4c23..ee8e8943 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -244,12 +244,12 @@ def _track_beam(self, incoming: Beam) -> Beam: return outgoing else: # ParticleBeam outgoing = ParticleBeam( - outgoing_particles, - outgoing_energy, + particles=outgoing_particles, + energy=outgoing_energy, particle_charges=incoming.particle_charges, + survived_probabilities=incoming.survived_probabilities, device=outgoing_particles.device, dtype=outgoing_particles.dtype, - survived_probabilities=incoming.survived_probabilities, ) return outgoing diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index cb7e1133..05c54c1f 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -260,9 +260,9 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, + survived_probabilities=incoming.survived_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, - survived_probabilities=incoming.survived_probabilities, ) return outgoing_beam diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 1fd72aae..96b87740 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -114,9 +114,9 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, + survived_probabilities=incoming.survived_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, - survived_probabilities=incoming.survived_probabilities, ) return outgoing_beam diff --git a/cheetah/accelerator/element.py 
b/cheetah/accelerator/element.py index dbbf9a3a..4dc7e810 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -81,9 +81,9 @@ def track(self, incoming: Beam) -> Beam: new_particles, incoming.energy, particle_charges=incoming.particle_charges, + survived_probabilities=incoming.survived_probabilities, device=new_particles.device, dtype=new_particles.dtype, - survived_probabilities=incoming.survived_probabilities, ) else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index dc546c65..ac7fad7c 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -187,12 +187,14 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + (x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, + survived_probabilities=incoming.survived_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, - survived_probabilities=incoming.survived_probabilities, ) return outgoing_beam diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index 0d621d93..a07f8d6a 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -207,8 +207,10 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + (x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, survived_probabilities=incoming.survived_probabilities, device=incoming.particles.device, From da13b301c299419e8b3e871fea6d55cbf37697c2 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 18:38:58 +0100 Subject: [PATCH 14/29] Fix `black` warning because of too long line --- cheetah/accelerator/space_charge_kick.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 84c29066..35b44718 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -583,8 +583,8 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=vectorized_incoming.particle_charges.flatten( end_dim=-2 ), - survived_probabilities=vectorized_incoming.survived_probabilities.flatten( - end_dim=-2 + survived_probabilities=( + vectorized_incoming.survived_probabilities.flatten(end_dim=-2) ), device=vectorized_incoming.particles.device, dtype=vectorized_incoming.particles.dtype, From 1caf5ff6eefb80e1b61f1278755d3262fde9f1ca Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 18:46:11 +0100 Subject: [PATCH 15/29] Clean up aperture test --- tests/test_vectorized.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index fa82f908..8ca85a5e 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -451,16 +452,16 @@ def test_broadcasting_solenoid_misalignment(): assert outgoing.energy.shape == (2,) 
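For illustration only (this snippet is not applied by any of the patches): a minimal, self-contained sketch of the torch broadcasting rule that the shape assertions in the reworked aperture test below rely on. It assumes nothing beyond standard torch, and the shapes are the ones used in that test.

    import torch

    # The aperture's x_max has shape (3, 1), the beam energy has shape (2,).
    x_max = torch.tensor([[1e-5], [2e-4], [3e-4]])
    energy = torch.tensor([154e6, 14e9])

    vector_shape = torch.broadcast_shapes(x_max.shape, energy.shape)
    print(vector_shape)  # torch.Size([3, 2])

    # With 100_000 macroparticles, per-particle quantities such as the survival
    # probabilities therefore end up with shape (3, 2, 100_000).
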
-@pytest.mark.parametrize("shape", ["rectangular", "elliptical"]) -def test_vectorized_aperture_broadcasting(shape): +@pytest.mark.parametrize("aperture_shape", ["rectangular", "elliptical"]) +def test_vectorized_aperture_broadcasting(aperture_shape): """ Test that apertures work in a vectorised setting and that broadcasting rules are applied correctly. """ torch.manual_seed(0) + incoming = cheetah.ParticleBeam.from_parameters( num_particles=100_000, - # mu_x=torch.tensor(1e-4), sigma_py=torch.tensor(1e-4), sigma_px=torch.tensor(2e-4), energy=torch.tensor([154e6, 14e9]), @@ -471,7 +472,7 @@ def test_vectorized_aperture_broadcasting(shape): cheetah.Aperture( x_max=torch.tensor([[1e-5], [2e-4], [3e-4]]), y_max=torch.tensor(2e-4), - shape=shape, + shape=aperture_shape, ), cheetah.Drift(length=torch.tensor(0.5)), ] @@ -483,17 +484,11 @@ def test_vectorized_aperture_broadcasting(shape): assert outgoing.particle_charges.shape == (100_000,) assert outgoing.energy.shape == (2,) - if shape == "elliptical": - assert torch.allclose( - outgoing.survived_probabilities.sum(dim=-1)[:, 0], - torch.tensor( - [7672, 94523, 99547], dtype=outgoing.survived_probabilities.dtype - ), + if aperture_shape == "elliptical": + assert np.allclose( + outgoing.survived_probabilities.sum(dim=-1)[:, 0], [7672, 94523, 99547] ) - elif shape == "rectangular": - assert torch.allclose( - outgoing.survived_probabilities.sum(dim=-1)[:, 0], - torch.tensor( - [7935, 95400, 99719], dtype=outgoing.survived_probabilities.dtype - ), + elif aperture_shape == "rectangular": + assert np.allclose( + outgoing.survived_probabilities.sum(dim=-1)[:, 0], [7935, 95400, 99719] ) From 73657cdb9354690222c66d1476bdc24203f8e691 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 19:09:27 +0100 Subject: [PATCH 16/29] Fix `Aperture` test and clean up `Aperture` code --- cheetah/accelerator/aperture.py | 51 ++++++++++++--------------------- tests/test_vectorized.py | 6 ++-- 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index 0d5005f0..ad1dad2f 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -14,11 +14,12 @@ class Aperture(Element): """ Physical aperture. - The aperture is only considered if tracking a `ParticleBeam` and the aperture is - active. - :param x_max: half size horizontal offset in [m] - :param y_max: half size vertical offset in [m] + NOTE: The aperture currently only affects beams of type `ParticleBeam` and only has + an effect when the aperture is active. + + :param x_max: half size horizontal offset in [m]. + :param y_max: half size vertical offset in [m]. :param shape: Shape of the aperture. Can be "rectangular" or "elliptical". :param is_active: If the aperture actually blocks particles. :param name: Unique identifier of the element. 
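For illustration only (not applied by the patch): a small standalone sketch of how such a survival mask acts on a beam's survival probabilities. The tensor values are made up for the example; the mask expressions mirror the `track` hunk below.

    import torch

    x = torch.tensor([0.0, 0.5e-3, 2.0e-3, -0.2e-3, 5.0e-3])
    y = torch.tensor([0.0, 0.1e-3, 0.0, 3.0e-3, 0.0])
    x_max = torch.tensor(1e-3)
    y_max = torch.tensor(1e-3)
    survived_probabilities = torch.ones(5)

    rectangular_mask = (x > -x_max) & (x < x_max) & (y > -y_max) & (y < y_max)
    elliptical_mask = (x**2 / x_max**2 + y**2 / y_max**2) <= 1.0

    # Lost particles are not removed from the tensor; their survival probability
    # is multiplied to zero, so all vectorised shapes stay intact.
    survived_probabilities = survived_probabilities * rectangular_mask
    print(survived_probabilities)  # tensor([1., 1., 0., 0., 0.])
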
@@ -80,42 +81,28 @@ def track(self, incoming: Beam) -> Beam: "elliptical", ], f"Unknown aperture shape {self.shape}" - # broadcast x_max and y_max and the ParticleBeam to the same shape - vector_shape = torch.broadcast_shapes( - self.x_max.shape, - self.y_max.shape, - incoming.x.shape[:-1], - incoming.energy.shape, - ) - x_max = self.x_max.expand(vector_shape).unsqueeze(-1) - y_max = self.y_max.expand(vector_shape).unsqueeze(-1) - outgoing_particles = incoming.particles.expand(vector_shape + (-1, 7)) - if self.shape == "rectangular": survived_mask = torch.logical_and( - torch.logical_and(incoming.x > -x_max, incoming.x < x_max), - torch.logical_and(incoming.y > -y_max, incoming.y < y_max), + torch.logical_and( + incoming.x > -self.x_max.unsqueeze(-1), + incoming.x < self.x_max.unsqueeze(-1), + ), + torch.logical_and( + incoming.y > -self.y_max.unsqueeze(-1), + incoming.y < self.y_max.unsqueeze(-1), + ), ) elif self.shape == "elliptical": - survived_mask = (incoming.x**2 / x_max**2 + incoming.y**2 / y_max**2) <= 1.0 - - outgoing_survival = incoming.survived_probabilities * survived_mask - - # outgoing_particles = incoming.particles[survived_mask] - - # outgoing_particle_charges = incoming.particle_charges[survived_mask] - - # self.lost_particles = incoming.particles[torch.logical_not(survived_mask)] - - # self.lost_particle_charges = incoming.particle_charges[ - # torch.logical_not(survived_mask) - # ] + survived_mask = ( + incoming.x**2 / self.x_max.unsqueeze(-1) ** 2 + + incoming.y**2 / self.y_max.unsqueeze(-1) ** 2 + ) <= 1.0 return ParticleBeam( - particles=outgoing_particles, + particles=incoming.particles, energy=incoming.energy, particle_charges=incoming.particle_charges, - survived_probabilities=outgoing_survival, + survived_probabilities=incoming.survived_probabilities * survived_mask, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 8ca85a5e..8941a6d6 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -480,9 +480,11 @@ def test_vectorized_aperture_broadcasting(aperture_shape): outgoing = segment.track(incoming) - assert outgoing.particles.shape == (3, 2, 100_000, 7) - assert outgoing.particle_charges.shape == (100_000,) + # Particle positions are unaffected by the aperture ... only their survival is + assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.energy.shape == (2,) + assert outgoing.particle_charges.shape == (100_000,) + assert outgoing.survived_probabilities.shape == (3, 2, 100_000) if aperture_shape == "elliptical": assert np.allclose( From fa3f0a65a6e7f5e455999a08e9f510589872d056 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 20:13:13 +0100 Subject: [PATCH 17/29] Some code cleanup in `ParameterBeam` --- cheetah/particles/particle_beam.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 1062959f..8c7125a1 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -27,8 +27,9 @@ class ParticleBeam(Beam): :param particles: List of 7-dimensional particle vectors. :param energy: Reference energy of the beam in eV. :param particle_charges: Charges of macroparticles in the beam in C. - :param survived_probabilities: Survival probability of each particle in the beam. - Default to array with ones. 
(1: survive, 0: lost) + :param survived_probabilities: Vector of probabilities that each particle has + survived (i.e. not been lost), where 1.0 means the particle has survived and + 0.0 means the particle has been lost. Defaults to ones. :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. :param dtype: Data type of the generated particles. @@ -890,24 +891,21 @@ def __len__(self) -> int: @property def total_charge(self) -> torch.Tensor: - """Returns the total charge of the beam in C. - Note: it does not take into account the survival of the particles. - """ - return torch.sum(self.particle_charges, dim=-1) - - @property - def total_charge_survived(self) -> torch.Tensor: - """Returns the total charge of the survived macroparticles, in C.""" + """Total charge of the beam in C, taking into account particle losses.""" return torch.sum(self.particle_charges * self.survived_probabilities, dim=-1) @property def num_particles(self) -> int: - """Length of the macroparticle array, does not account for lost.""" + """ + Length of the macroparticle array. + + NOTE: This does not account for lost particles. + """ return self.particles.shape[-2] @property def num_particles_survived(self) -> torch.Tensor: - """Returns the number of macroparticles that survived the simulation.""" + """Number of macroparticles that have survived.""" return self.survived_probabilities.sum(dim=-1) @property From 52116ab186c36c1962b332ac40be7a55c744832d Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 21:06:04 +0100 Subject: [PATCH 18/29] Remove `Beam.empty` --- cheetah/accelerator/bpm.py | 4 +- cheetah/accelerator/cavity.py | 4 +- cheetah/accelerator/element.py | 4 +- cheetah/accelerator/screen.py | 24 ++++++- cheetah/accelerator/space_charge_kick.py | 4 +- cheetah/particles/beam.py | 2 - cheetah/particles/particle_beam.py | 92 +++++++++--------------- 7 files changed, 59 insertions(+), 75 deletions(-) diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index d0e636bd..405e8d24 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -37,9 +37,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: ) def track(self, incoming: Beam) -> Beam: - if incoming is Beam.empty: - self.reading = None - elif isinstance(incoming, ParameterBeam): + if isinstance(incoming, ParameterBeam): self.reading = torch.stack([incoming.mu_x, incoming.mu_y]) elif isinstance(incoming, ParticleBeam): self.reading = torch.stack([incoming.mu_x, incoming.mu_y]) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index ee8e8943..27a50d3b 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -103,9 +103,7 @@ def track(self, incoming: Beam) -> Beam: :param incoming: Beam of particles entering the element. :return: Beam of particles exiting the element. """ - if incoming is Beam.empty: - return incoming - elif isinstance(incoming, (ParameterBeam, ParticleBeam)): + if isinstance(incoming, (ParameterBeam, ParticleBeam)): return self._track_beam(incoming) else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 4dc7e810..1df3efc8 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -60,9 +60,7 @@ def track(self, incoming: Beam) -> Beam: :param incoming: Beam of particles entering the element. 
:return: Beam of particles exiting the element. """ - if incoming is Beam.empty: - return incoming - elif isinstance(incoming, ParameterBeam): + if isinstance(incoming, ParameterBeam): tm = self.transfer_map(incoming.energy) mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1) cov = torch.matmul(tm, torch.matmul(incoming._cov, tm.transpose(-2, -1))) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index e1d6c0ea..540a405c 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -158,6 +158,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: return torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1)) def track(self, incoming: Beam) -> Beam: + # Record the beam only when the screen is active if self.is_active: copy_of_incoming = deepcopy(incoming) @@ -185,7 +186,26 @@ def track(self, incoming: Beam) -> Beam: self.set_read_beam(copy_of_incoming) - return Beam.empty if self.is_blocking else incoming + # Block the beam only when the screen is active and blocking + if self.is_active and self.is_blocking: + if isinstance(incoming, ParameterBeam): + return ParameterBeam( + mu=incoming._mu, + cov=incoming._cov, + energy=incoming.energy, + total_charge=torch.zeros_like(incoming.total_charge), + ) + elif isinstance(incoming, ParticleBeam): + return ParticleBeam( + particles=incoming.particles, + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survived_probabilities=torch.zeros_like( + incoming.survived_probabilities + ), + ) + else: + return deepcopy(incoming) @property def reading(self) -> torch.Tensor: @@ -194,7 +214,7 @@ def reading(self) -> torch.Tensor: return self.cached_reading read_beam = self.get_read_beam() - if read_beam is Beam.empty or read_beam is None: + if read_beam is None: image = torch.zeros( (int(self.effective_resolution[1]), int(self.effective_resolution[0])), device=self.misalignment.device, diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 35b44718..de523927 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -546,9 +546,7 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: :param incoming: Beam of particles entering the element. :returns: Beam of particles exiting the element. """ - if incoming is Beam.empty or incoming.particles.shape[0] == 0: - return incoming - elif isinstance(incoming, ParticleBeam): + if isinstance(incoming, ParticleBeam): # This flattening is a hack to only think about one vector dimension in the # following code. It is reversed at the end of the function. diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index 0e9b428b..5a331dc5 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -33,8 +33,6 @@ class directly, but use one of the subclasses. :math:`\Delta E = E - E_0` """ - empty = "I'm an empty beam!" - @classmethod @abstractmethod def from_parameters( diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 8c7125a1..843b32ff 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -26,7 +26,7 @@ class ParticleBeam(Beam): :param particles: List of 7-dimensional particle vectors. :param energy: Reference energy of the beam in eV. - :param particle_charges: Charges of macroparticles in the beam in C. + :param particle_charges: Charges of the macroparticles in the beam in C. 
:param survived_probabilities: Vector of probabilities that each particle has survived (i.e. not been lost), where 1.0 means the particle has survived and 0.0 means the particle has been lost. Defaults to ones. @@ -910,7 +910,7 @@ def num_particles_survived(self) -> torch.Tensor: @property def x(self) -> Optional[torch.Tensor]: - return self.particles[..., 0] if self is not Beam.empty else None + return self.particles[..., 0] @x.setter def x(self, value: torch.Tensor) -> None: @@ -918,28 +918,26 @@ def x(self, value: torch.Tensor) -> None: @property def mu_x(self) -> Optional[torch.Tensor]: - """Mean of the :math:`x` coordinates of the particles, weighted by their - survival probability.""" + """ + Mean of the :math:`x` coordinates of the particles, weighted by their + survival probability. + """ return ( torch.sum((self.x * self.survived_probabilities), dim=-1) / self.num_particles_survived - if self is not Beam.empty - else None ) @property def sigma_x(self) -> Optional[torch.Tensor]: - """Standard deviation of the :math:`x` coordinates of the particles, weighted - by their survival probability.""" - return ( - unbiased_weighted_std(self.x, self.survived_probabilities, dim=-1) - if self is not Beam.empty - else None - ) + """ + Standard deviation of the :math:`x` coordinates of the particles, weighted + by their survival probability. + """ + return unbiased_weighted_std(self.x, self.survived_probabilities, dim=-1) @property def px(self) -> Optional[torch.Tensor]: - return self.particles[..., 1] if self is not Beam.empty else None + return self.particles[..., 1] @px.setter def px(self, value: torch.Tensor) -> None: @@ -947,28 +945,26 @@ def px(self, value: torch.Tensor) -> None: @property def mu_px(self) -> Optional[torch.Tensor]: - """Mean of the :math:`px` coordinates of the particles, weighted by their - survival probability.""" + """ + Mean of the :math:`px` coordinates of the particles, weighted by their + survival probability. + """ return ( torch.sum((self.px * self.survived_probabilities), dim=-1) / self.num_particles_survived - if self is not Beam.empty - else None ) @property def sigma_px(self) -> Optional[torch.Tensor]: - """Standard deviation of the :math:`px` coordinates of the particles, weighted - by their survival probability.""" - return ( - unbiased_weighted_std(self.px, self.survived_probabilities, dim=-1) - if self is not Beam.empty - else None - ) + """ + Standard deviation of the :math:`px` coordinates of the particles, weighted + by their survival probability. 
+ """ + return unbiased_weighted_std(self.px, self.survived_probabilities, dim=-1) @property def y(self) -> Optional[torch.Tensor]: - return self.particles[..., 2] if self is not Beam.empty else None + return self.particles[..., 2] @y.setter def y(self, value: torch.Tensor) -> None: @@ -979,21 +975,15 @@ def mu_y(self) -> Optional[float]: return ( torch.sum((self.y * self.survived_probabilities), dim=-1) / self.num_particles_survived - if self is not Beam.empty - else None ) @property def sigma_y(self) -> Optional[torch.Tensor]: - return ( - unbiased_weighted_std(self.y, self.survived_probabilities, dim=-1) - if self is not Beam.empty - else None - ) + return unbiased_weighted_std(self.y, self.survived_probabilities, dim=-1) @property def py(self) -> Optional[torch.Tensor]: - return self.particles[..., 3] if self is not Beam.empty else None + return self.particles[..., 3] @py.setter def py(self, value: torch.Tensor) -> None: @@ -1004,21 +994,15 @@ def mu_py(self) -> Optional[torch.Tensor]: return ( torch.sum((self.py * self.survived_probabilities), dim=-1) / self.num_particles_survived - if self is not Beam.empty - else None ) @property def sigma_py(self) -> Optional[torch.Tensor]: - return ( - unbiased_weighted_std(self.py, self.survived_probabilities, dim=-1) - if self is not Beam.empty - else None - ) + return unbiased_weighted_std(self.py, self.survived_probabilities, dim=-1) @property def tau(self) -> Optional[torch.Tensor]: - return self.particles[..., 4] if self is not Beam.empty else None + return self.particles[..., 4] @tau.setter def tau(self, value: torch.Tensor) -> None: @@ -1029,21 +1013,15 @@ def mu_tau(self) -> Optional[torch.Tensor]: return ( torch.sum((self.tau * self.survived_probabilities), dim=-1) / self.num_particles_survived - if self is not Beam.empty - else None ) @property def sigma_tau(self) -> Optional[torch.Tensor]: - return ( - unbiased_weighted_std(self.tau, self.survived_probabilities, dim=-1) - if self is not Beam.empty - else None - ) + return unbiased_weighted_std(self.tau, self.survived_probabilities, dim=-1) @property def p(self) -> Optional[torch.Tensor]: - return self.particles[..., 5] if self is not Beam.empty else None + return self.particles[..., 5] @p.setter def p(self, value: torch.Tensor) -> None: @@ -1054,21 +1032,16 @@ def mu_p(self) -> Optional[torch.Tensor]: return ( torch.sum((self.p * self.survived_probabilities), dim=-1) / self.num_particles_survived - if self is not Beam.empty - else None ) @property def sigma_p(self) -> Optional[torch.Tensor]: - return ( - unbiased_weighted_std(self.p, self.survived_probabilities, dim=-1) - if self is not Beam.empty - else None - ) + return unbiased_weighted_std(self.p, self.survived_probabilities, dim=-1) @property def sigma_xpx(self) -> torch.Tensor: - r"""Returns the covariance between x and px. :math:`\sigma_{x, px}^2`. + r""" + Returns the covariance between x and px. :math:`\sigma_{x, px}^2`. It is weighted by the survival probability of the particles. """ return unbiased_weighted_covariance( @@ -1077,7 +1050,8 @@ def sigma_xpx(self) -> torch.Tensor: @property def sigma_ypy(self) -> torch.Tensor: - r"""Returns the covariance between y and py. :math:`\sigma_{y, py}^2`. + r""" + Returns the covariance between y and py. :math:`\sigma_{y, py}^2`. It is weighted by the survival probability of the particles. 
""" return unbiased_weighted_covariance( From bd6515ffd148e4c839fa5f8f4c8a6efdfbe67df7 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 21:08:04 +0100 Subject: [PATCH 19/29] Fix `flake8` warning about unused import --- cheetah/accelerator/space_charge_kick.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index de523927..fac55439 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -5,7 +5,7 @@ from scipy.constants import elementary_charge, epsilon_0, speed_of_light from cheetah.accelerator.element import Element -from cheetah.particles import Beam, ParticleBeam +from cheetah.particles import ParticleBeam from cheetah.utils import verify_device_and_dtype From db6db835b668e25f5a140cb45961da5b13b8ccbc Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 21:14:05 +0100 Subject: [PATCH 20/29] Minor code readiblity improvement --- cheetah/particles/particle_beam.py | 42 +++++++++++++----------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 843b32ff..ca69de57 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -922,10 +922,9 @@ def mu_x(self) -> Optional[torch.Tensor]: Mean of the :math:`x` coordinates of the particles, weighted by their survival probability. """ - return ( - torch.sum((self.x * self.survived_probabilities), dim=-1) - / self.num_particles_survived - ) + return torch.sum( + (self.x * self.survived_probabilities), dim=-1 + ) / self.survived_probabilities.sum(dim=-1) @property def sigma_x(self) -> Optional[torch.Tensor]: @@ -949,10 +948,9 @@ def mu_px(self) -> Optional[torch.Tensor]: Mean of the :math:`px` coordinates of the particles, weighted by their survival probability. 
""" - return ( - torch.sum((self.px * self.survived_probabilities), dim=-1) - / self.num_particles_survived - ) + return torch.sum( + (self.px * self.survived_probabilities), dim=-1 + ) / self.survived_probabilities.sum(dim=-1) @property def sigma_px(self) -> Optional[torch.Tensor]: @@ -972,10 +970,9 @@ def y(self, value: torch.Tensor) -> None: @property def mu_y(self) -> Optional[float]: - return ( - torch.sum((self.y * self.survived_probabilities), dim=-1) - / self.num_particles_survived - ) + return torch.sum( + (self.y * self.survived_probabilities), dim=-1 + ) / self.survived_probabilities.sum(dim=-1) @property def sigma_y(self) -> Optional[torch.Tensor]: @@ -991,10 +988,9 @@ def py(self, value: torch.Tensor) -> None: @property def mu_py(self) -> Optional[torch.Tensor]: - return ( - torch.sum((self.py * self.survived_probabilities), dim=-1) - / self.num_particles_survived - ) + return torch.sum( + (self.py * self.survived_probabilities), dim=-1 + ) / self.survived_probabilities.sum(dim=-1) @property def sigma_py(self) -> Optional[torch.Tensor]: @@ -1010,10 +1006,9 @@ def tau(self, value: torch.Tensor) -> None: @property def mu_tau(self) -> Optional[torch.Tensor]: - return ( - torch.sum((self.tau * self.survived_probabilities), dim=-1) - / self.num_particles_survived - ) + return torch.sum( + (self.tau * self.survived_probabilities), dim=-1 + ) / self.survived_probabilities.sum(dim=-1) @property def sigma_tau(self) -> Optional[torch.Tensor]: @@ -1029,10 +1024,9 @@ def p(self, value: torch.Tensor) -> None: @property def mu_p(self) -> Optional[torch.Tensor]: - return ( - torch.sum((self.p * self.survived_probabilities), dim=-1) - / self.num_particles_survived - ) + return torch.sum( + (self.p * self.survived_probabilities), dim=-1 + ) / self.survived_probabilities.sum(dim=-1) @property def sigma_p(self) -> Optional[torch.Tensor]: From 7a538d3a15911ac06b3d34ac81edfea678093948 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 21:24:13 +0100 Subject: [PATCH 21/29] Some more readibility imporvements to the code --- cheetah/particles/particle_beam.py | 28 ++++++++++++++++++++-------- cheetah/utils/statistics.py | 14 ++++++-------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index ca69de57..9e161d64 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -932,7 +932,9 @@ def sigma_x(self) -> Optional[torch.Tensor]: Standard deviation of the :math:`x` coordinates of the particles, weighted by their survival probability. """ - return unbiased_weighted_std(self.x, self.survived_probabilities, dim=-1) + return unbiased_weighted_std( + self.x, weights=self.survived_probabilities, dim=-1 + ) @property def px(self) -> Optional[torch.Tensor]: @@ -958,7 +960,9 @@ def sigma_px(self) -> Optional[torch.Tensor]: Standard deviation of the :math:`px` coordinates of the particles, weighted by their survival probability. 
""" - return unbiased_weighted_std(self.px, self.survived_probabilities, dim=-1) + return unbiased_weighted_std( + self.px, weights=self.survived_probabilities, dim=-1 + ) @property def y(self) -> Optional[torch.Tensor]: @@ -976,7 +980,9 @@ def mu_y(self) -> Optional[float]: @property def sigma_y(self) -> Optional[torch.Tensor]: - return unbiased_weighted_std(self.y, self.survived_probabilities, dim=-1) + return unbiased_weighted_std( + self.y, weights=self.survived_probabilities, dim=-1 + ) @property def py(self) -> Optional[torch.Tensor]: @@ -994,7 +1000,9 @@ def mu_py(self) -> Optional[torch.Tensor]: @property def sigma_py(self) -> Optional[torch.Tensor]: - return unbiased_weighted_std(self.py, self.survived_probabilities, dim=-1) + return unbiased_weighted_std( + self.py, weights=self.survived_probabilities, dim=-1 + ) @property def tau(self) -> Optional[torch.Tensor]: @@ -1012,7 +1020,9 @@ def mu_tau(self) -> Optional[torch.Tensor]: @property def sigma_tau(self) -> Optional[torch.Tensor]: - return unbiased_weighted_std(self.tau, self.survived_probabilities, dim=-1) + return unbiased_weighted_std( + self.tau, weights=self.survived_probabilities, dim=-1 + ) @property def p(self) -> Optional[torch.Tensor]: @@ -1030,7 +1040,9 @@ def mu_p(self) -> Optional[torch.Tensor]: @property def sigma_p(self) -> Optional[torch.Tensor]: - return unbiased_weighted_std(self.p, self.survived_probabilities, dim=-1) + return unbiased_weighted_std( + self.p, weights=self.survived_probabilities, dim=-1 + ) @property def sigma_xpx(self) -> torch.Tensor: @@ -1039,7 +1051,7 @@ def sigma_xpx(self) -> torch.Tensor: It is weighted by the survival probability of the particles. """ return unbiased_weighted_covariance( - self.x, self.px, self.survived_probabilities, dim=-1 + self.x, self.px, weights=self.survived_probabilities, dim=-1 ) @property @@ -1049,7 +1061,7 @@ def sigma_ypy(self) -> torch.Tensor: It is weighted by the survival probability of the particles. """ return unbiased_weighted_covariance( - self.y, self.py, self.survived_probabilities, dim=-1 + self.y, self.py, weights=self.survived_probabilities, dim=-1 ) @property diff --git a/cheetah/utils/statistics.py b/cheetah/utils/statistics.py index 8e5a3594..adfbeba2 100644 --- a/cheetah/utils/statistics.py +++ b/cheetah/utils/statistics.py @@ -4,14 +4,14 @@ def unbiased_weighted_covariance( input1: torch.Tensor, input2: torch.Tensor, weights: torch.Tensor, dim: int = None ) -> torch.Tensor: - """Compute the unbiased weighted covariance of two tensors. - inputs and weights should be broadcastable. + """ + Compute the unbiased weighted covariance of two tensors. - :param input1: Input tensor 1. (batch_size, sample_size) - :param input2: Input tensor 2. (batch_size, sample_size) - :param weights: Weights tensor. (batch_size, sample_size) + :param input1: Input tensor 1. (..., sample_size) + :param input2: Input tensor 2. (..., sample_size) + :param weights: Weights tensor. (..., sample_size) :param dim: Dimension along which to compute the covariance. - :return: Unbiased weighted covariance. (batch_size, 2, 2) + :return: Unbiased weighted covariance. (..., 2, 2) """ weighted_mean1 = torch.sum(input1 * weights, dim=dim) / torch.sum(weights, dim=dim) weighted_mean2 = torch.sum(input2 * weights, dim=dim) / torch.sum(weights, dim=dim) @@ -32,7 +32,6 @@ def unbiased_weighted_variance( ) -> torch.Tensor: """ Compute the unbiased weighted variance of a tensor. - inputs and weights should be broadcastable. :param input: Input tensor. :param weights: Weights tensor. 
@@ -54,7 +53,6 @@ def unbiased_weighted_std( ) -> torch.Tensor: """ Compute the unbiased weighted standard deviation of a tensor. - inputs and weights should be broadcastable. :param input: Input tensor. :param weights: Weights tensor. From 9e1189f01a257f2257c79a881acb0105bdbe2c25 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 21:38:35 +0100 Subject: [PATCH 22/29] Clean up statistics tests --- tests/test_statistics.py | 73 ++++++++++++++++++++++++++++ tests/test_statistics_calculation.py | 61 ----------------------- 2 files changed, 73 insertions(+), 61 deletions(-) create mode 100644 tests/test_statistics.py delete mode 100644 tests/test_statistics_calculation.py diff --git a/tests/test_statistics.py b/tests/test_statistics.py new file mode 100644 index 00000000..3b603851 --- /dev/null +++ b/tests/test_statistics.py @@ -0,0 +1,73 @@ +import torch + +from cheetah.utils import unbiased_weighted_covariance, unbiased_weighted_variance + + +def test_unbiased_weighted_variance_with_single_element(): + """Test that the variance is NaN when there is only one element.""" + data = torch.tensor([42.0]) + weights = torch.tensor([1.0]) + + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.isnan(computed_variance) + + +def test_unbiased_weighted_variance_with_same_weights(): + """ + Test that the weighted variance with all weights the same equals the unweighted + variance implementated in PyTorch. + """ + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) + + expected_variance = torch.var(data, unbiased=True) + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.allclose(computed_variance, expected_variance) + + +def test_unbiased_weighted_variance_with_different_weights(): + """Test that the variance is computed when some weights are different.""" + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.5, 0.5, 0.5, 0, 0]) + + expected_variance = torch.var(torch.tensor([1.0, 2.0, 3.0]), unbiased=True) + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.allclose(computed_variance, expected_variance) + + +def test_unbiased_weighted_variance_with_zero_weights(): + """Test that the variance is NaN when all weights are zero.""" + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]) + + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.isnan(computed_variance) + + +def test_unbiased_weighted_variance_with_small_numbers(): + """Test that the variance is correct for small numbers.""" + data = torch.tensor([1e-10, 2e-10, 3e-10, 4e-10, 5e-10], dtype=torch.float32) + weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32) + + expected_variance = torch.var(data, unbiased=True) + computed_variance = unbiased_weighted_variance(data, weights) + + assert torch.allclose(computed_variance, expected_variance) + + +def test_unbiased_weighted_covariance_reduced_to_variance(): + """ + Test that the covariance computation is correctly reduced to the variance when both + inputs are the same. 
+ """ + data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + weights = torch.tensor([0.5, 1.0, 1.0, 0.9, 0.9]) + + variance = unbiased_weighted_variance(data, weights) + covariance = unbiased_weighted_covariance(data, data, weights) + + assert torch.allclose(covariance, variance) diff --git a/tests/test_statistics_calculation.py b/tests/test_statistics_calculation.py deleted file mode 100644 index c7268a59..00000000 --- a/tests/test_statistics_calculation.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch - -from cheetah.utils import unbiased_weighted_covariance, unbiased_weighted_variance - - -def test_unbiased_weighted_variance_with_same_weights(): - """Test that the variance is calculated correctly with equal weights.""" - data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) - weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) - expected_variance = torch.var(data, unbiased=True) - calculated_variance = unbiased_weighted_variance(data, weights) - assert torch.allclose(calculated_variance, expected_variance) - - -def test_unbiased_weighted_variance_with_single_element(): - """Test that the variance is nan when there is only one element.""" - data = torch.tensor([42.0]) - weights = torch.tensor([1.0]) - assert torch.isnan(unbiased_weighted_variance(data, weights)) - - -def test_unbiased_weighted_variance_with_different_weights(): - """Test that the variance is calculated correctly with different weights.""" - data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) - weights = torch.tensor([0.5, 0.5, 0.5, 0, 0]) - expected_variance = torch.var(torch.tensor([1.0, 2.0, 3.0]), unbiased=True) - calculated_variance = unbiased_weighted_variance(data, weights) - assert torch.allclose(calculated_variance, expected_variance) - - -def test_unbiased_weighted_variance_with_zero_weights(): - """Test that the variance is nan when all weights are zero.""" - data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) - weights = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]) - assert torch.isnan(unbiased_weighted_variance(data, weights)) - - -def test_unbiased_weighted_variance_with_small_numbers(): - """Test that the variance is calculated correctly with small numbers.""" - data = torch.tensor([1e-10, 2e-10, 3e-10, 4e-10, 5e-10], dtype=torch.float32) - weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32) - expected_variance = torch.var(data, unbiased=True) - calculated_variance = unbiased_weighted_variance(data, weights) - assert torch.allclose(calculated_variance, expected_variance) - - -def test_unbiased_weighted_covariance_reduced_to_variance(): - """Test that the covariance calculation is reduced to the variance when both inputs - are the same. 
- """ - data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) - equal_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) - expected_variance = torch.var(data, unbiased=True) - calculated_covariance = unbiased_weighted_covariance(data, data, equal_weights) - assert torch.allclose(calculated_covariance, expected_variance) - - different_weights = torch.tensor([0.5, 1.0, 1.0, 0.9, 0.9]) - assert torch.allclose( - unbiased_weighted_covariance(data, data, different_weights), - unbiased_weighted_variance(data, different_weights), - ) From f1ca213dd7e833866185e5152646b1b5cb169bf2 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 21:42:23 +0100 Subject: [PATCH 23/29] Refactor assertions for readability in test_compare_ocelot.py --- tests/test_compare_ocelot.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index beb3bfa4..bc1ab083 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -223,14 +223,8 @@ def test_aperture_elliptical(): == outgoing_p_array.rparticles.shape[1] ) - assert np.allclose( - outgoing_beam.mu_x.cpu().numpy(), - outgoing_p_array.x().mean(), - ) - assert np.allclose( - outgoing_beam.mu_px.cpu().numpy(), - outgoing_p_array.px().mean(), - ) + assert np.allclose(outgoing_beam.mu_x.cpu().numpy(), outgoing_p_array.x().mean()) + assert np.allclose(outgoing_beam.mu_px.cpu().numpy(), outgoing_p_array.px().mean()) def test_solenoid(): From 2e44fb5d9e04f2eb8b84c09076d7f556bf140c0b Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 22:00:24 +0100 Subject: [PATCH 24/29] Properly fix Bmad-X test beam --- tests/resources/bmadx/incoming.pt | Bin 642366 -> 722584 bytes tests/test_dipole.py | 22 ++++++++++----------- tests/test_drift.py | 7 ++----- tests/test_quadrupole.py | 7 ++----- tests/test_transverse_deflecting_cavity.py | 13 +++++------- 5 files changed, 19 insertions(+), 30 deletions(-) diff --git a/tests/resources/bmadx/incoming.pt b/tests/resources/bmadx/incoming.pt index c901e8aaf5abc3383d723b7d55ec623fac1a7f0e..94b924045fa4b3a0d17047b0faa68e42dc09ecee 100644 GIT binary patch delta 82063 zcmeI5-)~c87{^c7vi1zdHVmu`!jA!CAkbg?DMDw5)6i(BU%pzJn~6;2+}xS@E3;#%R_i_WYh`GvXTjN)30anv zovcky7b;b!dBF+xW^8$DK38!|`Po9%39bCuecA5q%>?D`s#~t)W(rPAUS8yjEf>nQ z$=SlY=G_)&d&VaRMV%cPzkJQl*PPai^Bq3Rb#}g!w%%$#o(?WKZRudzx8$_nTKT5S z*Wu-S-Mw=8VoTAS9dq^+TjfyK=``ouaKdTd4QKDIVu$yOFTB`zaNIdiY)$)I=b#*W z;)EP>V!}D(I)^8m?#vz)aIJE!bgeLzov)N9bCadnl3S8*>PWHiDhcQ4`jE#qhU`&8 z9yizPbsf3t$y#QgTrlhXznj9rv-=Ysc>|M1Dz_3JTHK3B_AHB~Ba zEI4o0CeMtggP})DmkN)i2p;?E<7@W*u3grl&0?JV&@GfJQ-#V{YQN>J1onF)foNx& zb#C(zRj-lgSZY}9k9fBOEg!E2ZmtG=vNg%(mrebBn{3Up1@Eo~LU%rqi;O;Zk=R%& zxp9&9o2$#AyQ|CgouL;T8E-tYzIxA&%Ry(KyIf){HMnuPP+cz9y!a$3n#nKF*G1I8 z>*FrMzH<|Rj1N^njgJgKj}HvMjE@UoF{u)OO9A=ELy4T!FPU;ui2paN!m zTmbC7Ep#2vZ z0R127e(eA0|48>^|4;u1nvaw|`6vHK1;{`7M=HSnU+4!u10w&(>ZkvcfAWu1fY1Mc z=9AR(A2a|x{{Wu>oAGf0(EsWGNCoKsNcXdzkbh+Lv;Qam5i@{d$N)gP|<>Hp*(S^cX1kk!xoe_j8; z6d3++0g!+4k5oX_AFlf8|KuN8{q%qG4_thbnV0;Nf20EBpZp^g;OGCL?YI8?J^O#0 z0`z~R`?3G0|0CUx{XhMm{*P3^_%Hn*r2zRyR=-()N9A6PtN8;4K>k(#gL?)<{>eX5 z0rF4&kqSut0o{-OPyUhBFZBnlepUb20LVYE`IOX?fAWu1fc%qxqyqf>FZ~~>fc9Tt z0Q7&P`?3G0|0CUx{XhL5Xg*T<eX50Y3i& znom;Cf6xH*`~!RjY{th0K>w%zBNd?kBi)bx|0n$)YQOyXe~mx=AEf~KC;vzV*#9S? 
zGa&g#RzL6m$v^o=Dxm6*_y0Hr$Un0BmH*&!uNoi8KQsXHud;wk0r@BYNCn70`9~_i z&;J6=M-or|k=4(C|Bw8Wf20Cx{$u}-Q-J&{!jkNKT-kq z|3LGRzJvTDtDo!tY6&7dQoT{g3N^ zNCn70`9~_i{vT*Q(sz)5Wc73XkNlH=qyoJE2i|;2;>ka<`sx4VpZp^g(ET?s0QUbf z{{p9AuKyudzjXfr4dDOA|9CnW%8mZ@*+)Y@OKtUuHM=X^Zl$c!T)up@G&d8uR;W}< z<+-ucdCOY~M7-O9w))7LeQL)+EB5TD`ofyMx9PMx;&e1Rm>B3A8i@AAVzI&4@ZdmV zAR6oI?~4x&$D*;pc)Ty(Kal8;tHB?B|HJ$BCu?@N=h4!o!ehT>sjb7z0u6s|gi}x4 zw~ux0YS?30(SWrX$o2bnPgmQP{a-Zh|Mk9o8bz# delta 1289 zcmZ{kJ#5oJ6vyp2j+3~g4QiWy(U0^iY171UzT6fHjY@`EKpGJ80#@R z0qM^Gn_$Nl`yfr9R;rdF~WGN2`B zNQ!est>udH?S|J0Is_a-47eZ=@GX^Vz_nU-c}C8g(FNm>7^D_W`?B$x7^SY5 z_LA|97@NCl+SeMb-%|56)Z=B-nXKi4X`Zrwm=(Sr!v-D`^?iaK8u4JRwtbd+pvfxQ zPN}35Gq$Y@gWVX{#-P0zVp7RyeVL$Rk4T-wb<(1fgicy@5`Nlrk~$&nM^B)5>0jb- zD!G8fi~526t@Ct3uV)3)TOz@b%ijG%G;{*I0eIDSGL ZjWlmFL!aL_zEL>MlQtBi;Dz`G`y1e2OlAN8 diff --git a/tests/test_dipole.py b/tests/test_dipole.py index 56509fe0..231e7ddf 100644 --- a/tests/test_dipole.py +++ b/tests/test_dipole.py @@ -117,26 +117,24 @@ def test_dipole_bmadx_tracking(dtype): Test that the results of tracking through a dipole with the `"bmadx"` tracking method match the results from Bmad-X. """ - bmad_loaded = torch.load( - "tests/resources/bmadx/incoming.pt", weights_only=False - ).to(dtype) - incoming = ParticleBeam( - particles=bmad_loaded.particles, energy=bmad_loaded.energy, dtype=dtype + incoming = torch.load("tests/resources/bmadx/incoming.pt", weights_only=False).to( + dtype ) - angle = torch.tensor([20 * torch.pi / 180], dtype=dtype) + # TODO: See if Bmad-X test dtypes can be cleaned up now that dtype PR was merged + angle = torch.tensor(20 * torch.pi / 180, dtype=dtype) e1 = angle / 2 e2 = angle - e1 dipole_cheetah_bmadx = Dipole( - length=torch.tensor([0.5]), + length=torch.tensor(0.5), angle=angle, e1=e1, e2=e2, - tilt=torch.tensor([0.1], dtype=dtype), - fringe_integral=torch.tensor([0.5]), - fringe_integral_exit=torch.tensor([0.5]), - gap=torch.tensor([0.05], dtype=dtype), - gap_exit=torch.tensor([0.05], dtype=dtype), + tilt=torch.tensor(0.1, dtype=dtype), + fringe_integral=torch.tensor(0.5), + fringe_integral_exit=torch.tensor(0.5), + gap=torch.tensor(0.05, dtype=dtype), + gap_exit=torch.tensor(0.05, dtype=dtype), fringe_at="both", fringe_type="linear_edge", tracking_method="bmadx", diff --git a/tests/test_drift.py b/tests/test_drift.py index cffb05b7..bb5ba4b8 100644 --- a/tests/test_drift.py +++ b/tests/test_drift.py @@ -67,14 +67,11 @@ def test_drift_bmadx_tracking(dtype): Test that the results of tracking through a drift with the `"bmadx"` tracking method match the results from Bmad-X. """ - bmad_loaded = torch.load( + incoming_beam = torch.load( "tests/resources/bmadx/incoming.pt", weights_only=False ).to(dtype) - incoming_beam = cheetah.ParticleBeam( - particles=bmad_loaded.particles, energy=bmad_loaded.energy, dtype=dtype - ) drift = cheetah.Drift( - length=torch.tensor([1.0]), tracking_method="bmadx", dtype=dtype + length=torch.tensor(1.0), tracking_method="bmadx", dtype=dtype ) # Run tracking diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index eab8097c..dbdedac3 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -175,11 +175,8 @@ def test_quadrupole_bmadx_tracking(dtype): Test that the results of tracking through a quadrupole with the `"bmadx"` tracking method match the results from Bmad-X. 
""" - bmad_loaded = torch.load( - "tests/resources/bmadx/incoming.pt", weights_only=False - ).to(dtype) - incoming = ParticleBeam( - particles=bmad_loaded.particles, energy=bmad_loaded.energy, dtype=dtype + incoming = torch.load("tests/resources/bmadx/incoming.pt", weights_only=False).to( + dtype ) quadrupole = Quadrupole( length=torch.tensor(1.0), diff --git a/tests/test_transverse_deflecting_cavity.py b/tests/test_transverse_deflecting_cavity.py index 475d24a8..a9662aa8 100644 --- a/tests/test_transverse_deflecting_cavity.py +++ b/tests/test_transverse_deflecting_cavity.py @@ -10,17 +10,14 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype): Test that the results of tracking through a TDC with the `"bmadx"` tracking method match the results from Bmad-X. """ - bmad_loaded = torch.load( + incoming_beam = torch.load( "tests/resources/bmadx/incoming.pt", weights_only=False ).to(dtype) - incoming_beam = cheetah.ParticleBeam( - particles=bmad_loaded.particles, energy=bmad_loaded.energy, dtype=dtype - ) tdc = cheetah.TransverseDeflectingCavity( - length=torch.tensor([1.0]), - voltage=torch.tensor([1e7]), - phase=torch.tensor([0.2], dtype=dtype), - frequency=torch.tensor([1e9]), + length=torch.tensor(1.0), + voltage=torch.tensor(1e7), + phase=torch.tensor(0.2, dtype=dtype), + frequency=torch.tensor(1e9), tracking_method="bmadx", dtype=dtype, ) From 3f18d18b91c127f6dd3abd91f3d410446d731bf9 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 22:04:52 +0100 Subject: [PATCH 25/29] Rename to surval probabilities properties --- CHANGELOG.md | 2 +- cheetah/accelerator/aperture.py | 2 +- cheetah/accelerator/cavity.py | 2 +- cheetah/accelerator/dipole.py | 2 +- cheetah/accelerator/drift.py | 2 +- cheetah/accelerator/element.py | 2 +- cheetah/accelerator/quadrupole.py | 2 +- cheetah/accelerator/screen.py | 4 +- cheetah/accelerator/space_charge_kick.py | 16 ++--- .../transverse_deflecting_cavity.py | 2 +- cheetah/particles/particle_beam.py | 58 +++++++++--------- tests/resources/bmadx/incoming.pt | Bin 722584 -> 722584 bytes tests/test_space_charge_kick.py | 2 +- tests/test_vectorized.py | 6 +- 14 files changed, 51 insertions(+), 51 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77afef54..e0493955 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe) - `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) - The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324) -- Rework the `Aperture` element. Now `ParticleBeam` has a `survived_probabilites` attribute that keeps track of the lost particles. The statistical beam parameters are calculated only w.r.t. surviving particles. Note that the `Aperture` breaks differentiability if activated. (see #268) (@cr-xu, @jank324) +- Rework the `Aperture` element. Now `ParticleBeam` has a `survival_probabilities` attribute that keeps track of the lost particles. The statistical beam parameters are calculated only w.r.t. 
surviving particles. Note that the `Aperture` breaks differentiability if activated. (see #268) (@cr-xu, @jank324) ### 🚀 Features diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index ad1dad2f..71f1b329 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -102,7 +102,7 @@ def track(self, incoming: Beam) -> Beam: particles=incoming.particles, energy=incoming.energy, particle_charges=incoming.particle_charges, - survived_probabilities=incoming.survived_probabilities * survived_mask, + survival_probabilities=incoming.survival_probabilities * survived_mask, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 27a50d3b..8e4f9f5e 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -245,7 +245,7 @@ def _track_beam(self, incoming: Beam) -> Beam: particles=outgoing_particles, energy=outgoing_energy, particle_charges=incoming.particle_charges, - survived_probabilities=incoming.survived_probabilities, + survival_probabilities=incoming.survival_probabilities, device=outgoing_particles.device, dtype=outgoing_particles.dtype, ) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 05c54c1f..393c65f6 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -260,7 +260,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, - survived_probabilities=incoming.survived_probabilities, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 96b87740..ff63f371 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -114,7 +114,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, - survived_probabilities=incoming.survived_probabilities, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 1df3efc8..e050d106 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -79,7 +79,7 @@ def track(self, incoming: Beam) -> Beam: new_particles, incoming.energy, particle_charges=incoming.particle_charges, - survived_probabilities=incoming.survived_probabilities, + survival_probabilities=incoming.survival_probabilities, device=new_particles.device, dtype=new_particles.dtype, ) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index ac7fad7c..99ee3c0a 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -192,7 +192,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, - survived_probabilities=incoming.survived_probabilities, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 540a405c..c7a56331 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -200,8 +200,8 @@ def track(self, incoming: Beam) -> Beam: particles=incoming.particles, 
energy=incoming.energy, particle_charges=incoming.particle_charges, - survived_probabilities=torch.zeros_like( - incoming.survived_probabilities + survival_probabilities=torch.zeros_like( + incoming.survival_probabilities ), ) else: diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index fac55439..74808395 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -149,7 +149,7 @@ def _deposit_charge_on_grid( ) # Accumulate the charge contributions - survived_particle_charges = beam.particle_charges * beam.survived_probabilities + survived_particle_charges = beam.particle_charges * beam.survival_probabilities repeated_charges = survived_particle_charges.repeat_interleave( repeats=8, dim=-1 ) # Shape:(..., 8 * num_particles) @@ -556,7 +556,7 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: incoming.particles.shape[:-2], incoming.energy.shape, incoming.particle_charges.shape[:-1], - incoming.survived_probabilities.shape[:-1], + incoming.survival_probabilities.shape[:-1], (1,), ) vectorized_incoming = ParticleBeam( @@ -567,8 +567,8 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=torch.broadcast_to( incoming.particle_charges, (*vector_shape, incoming.num_particles) ), - survived_probabilities=torch.broadcast_to( - incoming.survived_probabilities, + survival_probabilities=torch.broadcast_to( + incoming.survival_probabilities, (*vector_shape, incoming.num_particles), ), device=incoming.particles.device, @@ -581,8 +581,8 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: particle_charges=vectorized_incoming.particle_charges.flatten( end_dim=-2 ), - survived_probabilities=( - vectorized_incoming.survived_probabilities.flatten(end_dim=-2) + survival_probabilities=( + vectorized_incoming.survival_probabilities.flatten(end_dim=-2) ), device=vectorized_incoming.particles.device, dtype=vectorized_incoming.particles.dtype, @@ -627,7 +627,7 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: incoming.particles.shape[:-2], incoming.energy.shape, incoming.particle_charges.shape[:-1], - incoming.survived_probabilities.shape[:-1], + incoming.survival_probabilities.shape[:-1], self.effect_length.shape, ) outgoing = ParticleBeam.from_xyz_pxpypz( @@ -636,7 +636,7 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=incoming.energy, particle_charges=incoming.particle_charges, - survived_probabilities=incoming.survived_probabilities, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index a07f8d6a..6f33417d 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -212,7 +212,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ), energy=ref_energy, particle_charges=incoming.particle_charges, - survived_probabilities=incoming.survived_probabilities, + survival_probabilities=incoming.survival_probabilities, device=incoming.particles.device, dtype=incoming.particles.dtype, ) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 9e161d64..23b078a2 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -27,7 +27,7 @@ class ParticleBeam(Beam): :param particles: List of 7-dimensional particle vectors. 
:param energy: Reference energy of the beam in eV. :param particle_charges: Charges of the macroparticles in the beam in C. - :param survived_probabilities: Vector of probabilities that each particle has + :param survival_probabilities: Vector of probabilities that each particle has survived (i.e. not been lost), where 1.0 means the particle has survived and 0.0 means the particle has been lost. Defaults to ones. :param device: Device to move the beam's particle array to. If set to `"auto"` a @@ -40,7 +40,7 @@ def __init__( particles: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, - survived_probabilities: Optional[torch.Tensor] = None, + survival_probabilities: Optional[torch.Tensor] = None, device=None, dtype=None, ) -> None: @@ -65,10 +65,10 @@ def __init__( ) self.register_buffer("energy", energy.to(**factory_kwargs)) self.register_buffer( - "survived_probabilities", + "survival_probabilities", ( - survived_probabilities.to(**factory_kwargs) - if survived_probabilities is not None + survival_probabilities.to(**factory_kwargs) + if survival_probabilities is not None else torch.ones(particles.shape[-2], **factory_kwargs) ), ) @@ -812,7 +812,7 @@ def from_xyz_pxpypz( xp_coordinates: torch.Tensor, energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, - survived_probabilities: Optional[torch.Tensor] = None, + survival_probabilities: Optional[torch.Tensor] = None, device=None, dtype=torch.float32, ) -> torch.Tensor: @@ -825,7 +825,7 @@ def from_xyz_pxpypz( particles=xp_coordinates.clone(), energy=energy, particle_charges=particle_charges, - survived_probabilities=survived_probabilities, + survival_probabilities=survival_probabilities, device=device, dtype=dtype, ) @@ -892,7 +892,7 @@ def __len__(self) -> int: @property def total_charge(self) -> torch.Tensor: """Total charge of the beam in C, taking into account particle losses.""" - return torch.sum(self.particle_charges * self.survived_probabilities, dim=-1) + return torch.sum(self.particle_charges * self.survival_probabilities, dim=-1) @property def num_particles(self) -> int: @@ -906,7 +906,7 @@ def num_particles(self) -> int: @property def num_particles_survived(self) -> torch.Tensor: """Number of macroparticles that have survived.""" - return self.survived_probabilities.sum(dim=-1) + return self.survival_probabilities.sum(dim=-1) @property def x(self) -> Optional[torch.Tensor]: @@ -923,8 +923,8 @@ def mu_x(self) -> Optional[torch.Tensor]: survival probability. """ return torch.sum( - (self.x * self.survived_probabilities), dim=-1 - ) / self.survived_probabilities.sum(dim=-1) + (self.x * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_x(self) -> Optional[torch.Tensor]: @@ -933,7 +933,7 @@ def sigma_x(self) -> Optional[torch.Tensor]: by their survival probability. """ return unbiased_weighted_std( - self.x, weights=self.survived_probabilities, dim=-1 + self.x, weights=self.survival_probabilities, dim=-1 ) @property @@ -951,8 +951,8 @@ def mu_px(self) -> Optional[torch.Tensor]: survival probability. """ return torch.sum( - (self.px * self.survived_probabilities), dim=-1 - ) / self.survived_probabilities.sum(dim=-1) + (self.px * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_px(self) -> Optional[torch.Tensor]: @@ -961,7 +961,7 @@ def sigma_px(self) -> Optional[torch.Tensor]: by their survival probability. 
""" return unbiased_weighted_std( - self.px, weights=self.survived_probabilities, dim=-1 + self.px, weights=self.survival_probabilities, dim=-1 ) @property @@ -975,13 +975,13 @@ def y(self, value: torch.Tensor) -> None: @property def mu_y(self) -> Optional[float]: return torch.sum( - (self.y * self.survived_probabilities), dim=-1 - ) / self.survived_probabilities.sum(dim=-1) + (self.y * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_y(self) -> Optional[torch.Tensor]: return unbiased_weighted_std( - self.y, weights=self.survived_probabilities, dim=-1 + self.y, weights=self.survival_probabilities, dim=-1 ) @property @@ -995,13 +995,13 @@ def py(self, value: torch.Tensor) -> None: @property def mu_py(self) -> Optional[torch.Tensor]: return torch.sum( - (self.py * self.survived_probabilities), dim=-1 - ) / self.survived_probabilities.sum(dim=-1) + (self.py * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_py(self) -> Optional[torch.Tensor]: return unbiased_weighted_std( - self.py, weights=self.survived_probabilities, dim=-1 + self.py, weights=self.survival_probabilities, dim=-1 ) @property @@ -1015,13 +1015,13 @@ def tau(self, value: torch.Tensor) -> None: @property def mu_tau(self) -> Optional[torch.Tensor]: return torch.sum( - (self.tau * self.survived_probabilities), dim=-1 - ) / self.survived_probabilities.sum(dim=-1) + (self.tau * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_tau(self) -> Optional[torch.Tensor]: return unbiased_weighted_std( - self.tau, weights=self.survived_probabilities, dim=-1 + self.tau, weights=self.survival_probabilities, dim=-1 ) @property @@ -1035,13 +1035,13 @@ def p(self, value: torch.Tensor) -> None: @property def mu_p(self) -> Optional[torch.Tensor]: return torch.sum( - (self.p * self.survived_probabilities), dim=-1 - ) / self.survived_probabilities.sum(dim=-1) + (self.p * self.survival_probabilities), dim=-1 + ) / self.survival_probabilities.sum(dim=-1) @property def sigma_p(self) -> Optional[torch.Tensor]: return unbiased_weighted_std( - self.p, weights=self.survived_probabilities, dim=-1 + self.p, weights=self.survival_probabilities, dim=-1 ) @property @@ -1051,7 +1051,7 @@ def sigma_xpx(self) -> torch.Tensor: It is weighted by the survival probability of the particles. """ return unbiased_weighted_covariance( - self.x, self.px, weights=self.survived_probabilities, dim=-1 + self.x, self.px, weights=self.survival_probabilities, dim=-1 ) @property @@ -1061,7 +1061,7 @@ def sigma_ypy(self) -> torch.Tensor: It is weighted by the survival probability of the particles. """ return unbiased_weighted_covariance( - self.y, self.py, weights=self.survived_probabilities, dim=-1 + self.y, self.py, weights=self.survival_probabilities, dim=-1 ) @property diff --git a/tests/resources/bmadx/incoming.pt b/tests/resources/bmadx/incoming.pt index 94b924045fa4b3a0d17047b0faa68e42dc09ecee..5279a3d6db98f4542ff999f4ce340cbb472653a7 100644 GIT binary patch delta 130 zcmbQyt23ilXTwWIro^1huNf_wxNaFvjefw)!0>>1`UGZfnPxARb}tr2AZ7w$W*`Rf zS%H`hh}nUd1Bf}dd$Dkx{V!=`Y+z<=Zf;>_VQy(+Y+z^-;LXnAJTc?)bORP{2^Nsq S+kIHL*D|tz1^=>g&jtVnizcT4 delta 130 zcmbQyt23ilXTwWIrqqg&jtV`!zZW! 
diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 6faee89f..949e38ec 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -315,4 +315,4 @@ def test_space_charge_with_aperture_cutoff(): ) # Check that the number of surviving particles is less than the initial number - assert outgoing_beam_with_aperture.survived_probabilities.sum(dim=-1).max() < 10_000 + assert outgoing_beam_with_aperture.survival_probabilities.sum(dim=-1).max() < 10_000 diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 8941a6d6..aa00598a 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -484,13 +484,13 @@ def test_vectorized_aperture_broadcasting(aperture_shape): assert outgoing.particles.shape == (2, 100_000, 7) assert outgoing.energy.shape == (2,) assert outgoing.particle_charges.shape == (100_000,) - assert outgoing.survived_probabilities.shape == (3, 2, 100_000) + assert outgoing.survival_probabilities.shape == (3, 2, 100_000) if aperture_shape == "elliptical": assert np.allclose( - outgoing.survived_probabilities.sum(dim=-1)[:, 0], [7672, 94523, 99547] + outgoing.survival_probabilities.sum(dim=-1)[:, 0], [7672, 94523, 99547] ) elif aperture_shape == "rectangular": assert np.allclose( - outgoing.survived_probabilities.sum(dim=-1)[:, 0], [7935, 95400, 99719] + outgoing.survival_probabilities.sum(dim=-1)[:, 0], [7935, 95400, 99719] ) From fdfe7172534f25cbb363102c9c2203b364d45b19 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 22:14:23 +0100 Subject: [PATCH 26/29] Minor test code cleanup --- tests/test_space_charge_kick.py | 36 ++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 949e38ec..dad61da0 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -281,9 +281,24 @@ def test_space_charge_with_ares_astra_beam(): def test_space_charge_with_aperture_cutoff(): """ - Tests that the space charge kick is correctly calculated only for the surviving - particles when an aperture is used. + Tests that the space charge kick is correctly applied only to surviving particles, + by comparing the results with and without an aperture that results in beam losses. 
""" + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(0.2)), + cheetah.Aperture( + x_max=torch.tensor(1e-4), + y_max=torch.tensor(1e-4), + shape="rectangular", + is_active="False", + name="aperture", + ), + cheetah.Drift(length=torch.tensor(0.25)), + cheetah.SpaceChargeKick(effect_length=torch.tensor(0.5)), + cheetah.Drift(length=torch.tensor(0.25)), + ] + ) incoming_beam = cheetah.ParticleBeam.from_parameters( num_particles=torch.tensor(10_000), total_charge=torch.tensor(1e-9), @@ -292,27 +307,16 @@ def test_space_charge_with_aperture_cutoff(): sigma_py=torch.tensor(1e-4), ) - drift1 = cheetah.Drift(length=torch.tensor(0.2)) - aperture = cheetah.Aperture( - x_max=torch.tensor(1e-4), - y_max=torch.tensor(1e-4), - shape="rectangular", - is_active="False", - ) - drift2 = cheetah.Drift(length=torch.tensor(0.25)) - space_charge = cheetah.SpaceChargeKick(effect_length=torch.tensor(0.5)) - segment = cheetah.Segment(elements=[drift1, aperture, drift2, space_charge, drift2]) - + # Track with inactive aperture outgoing_beam_without_aperture = segment.track(incoming_beam) - # Activate the aperture - aperture.is_active = True + # Activate the aperture and track the beam + segment.aperture.is_active = True outgoing_beam_with_aperture = segment.track(incoming_beam) # Check that with particle loss the space charge kick is different assert not torch.allclose( outgoing_beam_with_aperture.particles, outgoing_beam_without_aperture.particles ) - # Check that the number of surviving particles is less than the initial number assert outgoing_beam_with_aperture.survival_probabilities.sum(dim=-1).max() < 10_000 From d791fc612c8708dfa0d75124c9e3cf1b8f90de17 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 22:19:16 +0100 Subject: [PATCH 27/29] Adapt changelog message --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e0493955..aa05e139 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,10 @@ This is a major release with significant upgrades under the hood of Cheetah. Des ### 🚨 Breaking Changes - Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265, #284) (@jank324, @cr-xu, @hespe, @roussel-ryan) +- As part of the vectorised rewrite, the `Aperture` no longer removes particles. Instead, `ParticleBeam.survival_probabilities` tracks the probability that a particle has survived (i.e. the inverse probability that it has been lost). Note that particle losses in `Aperture` are currently not differentiable. This will be addressed in a future release. (see #268) (@cr-xu, @jank324) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe) - `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. 
(see #208) (@jank324, @roussel-ryan) - The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324) -- Rework the `Aperture` element. Now `ParticleBeam` has a `survival_probabilities` attribute that keeps track of the lost particles. The statistical beam parameters are calculated only w.r.t. surviving particles. Note that the `Aperture` breaks differentiability if activated. (see #268) (@cr-xu, @jank324) ### 🚀 Features From d11da6f21ba4f40552fb3ad3777dae06962af0ab Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 22:20:17 +0100 Subject: [PATCH 28/29] Further adapt changelog message --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa05e139..dec4d198 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des ### 🚨 Breaking Changes - Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265, #284) (@jank324, @cr-xu, @hespe, @roussel-ryan) -- As part of the vectorised rewrite, the `Aperture` no longer removes particles. Instead, `ParticleBeam.survival_probabilities` tracks the probability that a particle has survived (i.e. the inverse probability that it has been lost). Note that particle losses in `Aperture` are currently not differentiable. This will be addressed in a future release. (see #268) (@cr-xu, @jank324) +- As part of the vectorised rewrite, the `Aperture` no longer removes particles. Instead, `ParticleBeam.survival_probabilities` tracks the probability that a particle has survived (i.e. the inverse probability that it has been lost). This also comes with the removal of `Beam.empty`. Note that particle losses in `Aperture` are currently not differentiable. This will be addressed in a future release. (see #268) (@cr-xu, @jank324) - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe) - `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) - The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. 
(see #254) (@hespe, @jank324) From 8aa27abb50acb0d425491459a18a5d948bbbbb1b Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 28 Nov 2024 22:35:13 +0100 Subject: [PATCH 29/29] Add screen reading weighting of macroparticles by charge and survival --- cheetah/accelerator/screen.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index c7a56331..93132dd1 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -275,16 +275,24 @@ def reading(self) -> torch.Tensor: ) image, _ = torch.histogramdd( - torch.stack((read_beam.x, read_beam.y)).T, bins=self.pixel_bin_edges + torch.stack((read_beam.x, read_beam.y)).T, + bins=self.pixel_bin_edges, + weight=read_beam.particle_charges + * read_beam.survival_probabilities, ) image = torch.flipud(image.T) elif self.method == "kde": + weights = read_beam.particle_charges * read_beam.survival_probabilities + broadcasted_x, broadcasted_y, broadcasted_weights = ( + torch.broadcast_tensors(read_beam.x, read_beam.y, weights) + ) image = kde_histogram_2d( - x1=read_beam.x, - x2=read_beam.y, + x1=broadcasted_x, + x2=broadcasted_y, bins1=self.pixel_bin_centers[0], bins2=self.pixel_bin_centers[1], bandwidth=self.kde_bandwidth, + weights=broadcasted_weights, ) # Change the x, y positions image = torch.transpose(image, -2, -1)
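
The weighting introduced in the last patch above multiplies each macroparticle's contribution to the screen image by `particle_charges * survival_probabilities`, the same survival weighting the earlier patches apply to the beam statistics. The following is a minimal, PyTorch-only sketch of that idea; the toy Gaussian beam, the hard-edged rectangular aperture and all variable names are illustrative assumptions and not Cheetah code:

    import torch

    # Toy transverse coordinates for a Gaussian bunch of macroparticles (assumed).
    num_particles = 100_000
    x = torch.randn(num_particles) * 2e-4
    y = torch.randn(num_particles) * 2e-4
    particle_charges = torch.full((num_particles,), 1e-9 / num_particles)
    survival_probabilities = torch.ones(num_particles)

    # Hard-edged rectangular aperture: instead of removing particles, zero out
    # their survival probability so that tensor shapes stay fixed.
    x_max, y_max = 1e-4, 1e-4
    survived_mask = (x.abs() <= x_max) & (y.abs() <= y_max)
    survival_probabilities = survival_probabilities * survived_mask

    # Survival-weighted centroid, mirroring the weighted beam statistics.
    mu_x = (x * survival_probabilities).sum() / survival_probabilities.sum()

    # Charge- and survival-weighted screen image, as in the histogram branch.
    weights = particle_charges * survival_probabilities
    edges = torch.linspace(-5e-4, 5e-4, 101)
    image, _ = torch.histogramdd(
        torch.stack((x, y)).T, bins=(edges, edges), weight=weights
    )

    # The image integrates to the surviving charge, not the full bunch charge.
    print(mu_x.item(), image.sum().item())

Because lost particles keep their entries and only have their weight driven to zero, the particle tensors keep a fixed shape, which is what lets the vectorised tracking and broadcasting in the earlier patches work unchanged.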