Skip to content

Commit

Permalink
Merge pull request #268 from desy-ml/241-vectorised-aperture-tracking…
Browse files Browse the repository at this point in the history
…-is-broken

Fix `Aperture` vectorisation issue
  • Loading branch information
jank324 authored Nov 28, 2024
2 parents 63b9623 + 8aa27ab commit 6cb0c99
Show file tree
Hide file tree
Showing 23 changed files with 482 additions and 129 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265, #284) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- As part of the vectorised rewrite, the `Aperture` no longer removes particles. Instead, `ParticleBeam.survival_probabilities` tracks the probability that a particle has survived (i.e. the inverse probability that it has been lost). This also comes with the removal of `Beam.empty`. Note that particle losses in `Aperture` are currently not differentiable. This will be addressed in a future release. (see #268) (@cr-xu, @jank324)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)
- The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324)
Expand Down
48 changes: 23 additions & 25 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ class Aperture(Element):
"""
Physical aperture.
:param x_max: half size horizontal offset in [m]
:param y_max: half size vertical offset in [m]
NOTE: The aperture currently only affects beams of type `ParticleBeam` and only has
an effect when the aperture is active.
:param x_max: half size horizontal offset in [m].
:param y_max: half size vertical offset in [m].
:param shape: Shape of the aperture. Can be "rectangular" or "elliptical".
:param is_active: If the aperture actually blocks particles.
:param name: Unique identifier of the element.
Expand Down Expand Up @@ -72,41 +75,36 @@ def track(self, incoming: Beam) -> Beam:
if not (isinstance(incoming, ParticleBeam) and self.is_active):
return incoming

assert self.x_max >= 0 and self.y_max >= 0
assert torch.all(self.x_max >= 0) and torch.all(self.y_max >= 0)
assert self.shape in [
"rectangular",
"elliptical",
], f"Unknown aperture shape {self.shape}"

if self.shape == "rectangular":
survived_mask = torch.logical_and(
torch.logical_and(incoming.x > -self.x_max, incoming.x < self.x_max),
torch.logical_and(incoming.y > -self.y_max, incoming.y < self.y_max),
torch.logical_and(
incoming.x > -self.x_max.unsqueeze(-1),
incoming.x < self.x_max.unsqueeze(-1),
),
torch.logical_and(
incoming.y > -self.y_max.unsqueeze(-1),
incoming.y < self.y_max.unsqueeze(-1),
),
)
elif self.shape == "elliptical":
survived_mask = (
incoming.x**2 / self.x_max**2 + incoming.y**2 / self.y_max**2
incoming.x**2 / self.x_max.unsqueeze(-1) ** 2
+ incoming.y**2 / self.y_max.unsqueeze(-1) ** 2
) <= 1.0
outgoing_particles = incoming.particles[survived_mask]

outgoing_particle_charges = incoming.particle_charges[survived_mask]

self.lost_particles = incoming.particles[torch.logical_not(survived_mask)]

self.lost_particle_charges = incoming.particle_charges[
torch.logical_not(survived_mask)
]

return (
ParticleBeam(
outgoing_particles,
incoming.energy,
particle_charges=outgoing_particle_charges,
device=outgoing_particles.device,
dtype=outgoing_particles.dtype,
)
if outgoing_particles.shape[0] > 0
else ParticleBeam.empty
return ParticleBeam(
particles=incoming.particles,
energy=incoming.energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities * survived_mask,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)

def split(self, resolution: torch.Tensor) -> list[Element]:
Expand Down
4 changes: 1 addition & 3 deletions cheetah/accelerator/bpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
)

def track(self, incoming: Beam) -> Beam:
if incoming is Beam.empty:
self.reading = None
elif isinstance(incoming, ParameterBeam):
if isinstance(incoming, ParameterBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
elif isinstance(incoming, ParticleBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
Expand Down
9 changes: 4 additions & 5 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ def track(self, incoming: Beam) -> Beam:
:param incoming: Beam of particles entering the element.
:return: Beam of particles exiting the element.
"""
if incoming is Beam.empty:
return incoming
elif isinstance(incoming, (ParameterBeam, ParticleBeam)):
if isinstance(incoming, (ParameterBeam, ParticleBeam)):
return self._track_beam(incoming)
else:
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")
Expand Down Expand Up @@ -244,9 +242,10 @@ def _track_beam(self, incoming: Beam) -> Beam:
return outgoing
else: # ParticleBeam
outgoing = ParticleBeam(
outgoing_particles,
outgoing_energy,
particles=outgoing_particles,
energy=outgoing_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=outgoing_particles.device,
dtype=outgoing_particles.dtype,
)
Expand Down
1 change: 1 addition & 0 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
Expand Down
1 change: 1 addition & 0 deletions cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
Expand Down
5 changes: 2 additions & 3 deletions cheetah/accelerator/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def track(self, incoming: Beam) -> Beam:
:param incoming: Beam of particles entering the element.
:return: Beam of particles exiting the element.
"""
if incoming is Beam.empty:
return incoming
elif isinstance(incoming, ParameterBeam):
if isinstance(incoming, ParameterBeam):
tm = self.transfer_map(incoming.energy)
mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1)
cov = torch.matmul(tm, torch.matmul(incoming._cov, tm.transpose(-2, -1)))
Expand All @@ -81,6 +79,7 @@ def track(self, incoming: Beam) -> Beam:
new_particles,
incoming.energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=new_particles.device,
dtype=new_particles.dtype,
)
Expand Down
7 changes: 5 additions & 2 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,12 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
)

outgoing_beam = ParticleBeam(
torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1),
ref_energy,
particles=torch.stack(
(x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
Expand Down
38 changes: 33 additions & 5 deletions cheetah/accelerator/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1))

def track(self, incoming: Beam) -> Beam:
# Record the beam only when the screen is active
if self.is_active:
copy_of_incoming = deepcopy(incoming)

Expand Down Expand Up @@ -185,7 +186,26 @@ def track(self, incoming: Beam) -> Beam:

self.set_read_beam(copy_of_incoming)

return Beam.empty if self.is_blocking else incoming
# Block the beam only when the screen is active and blocking
if self.is_active and self.is_blocking:
if isinstance(incoming, ParameterBeam):
return ParameterBeam(
mu=incoming._mu,
cov=incoming._cov,
energy=incoming.energy,
total_charge=torch.zeros_like(incoming.total_charge),
)
elif isinstance(incoming, ParticleBeam):
return ParticleBeam(
particles=incoming.particles,
energy=incoming.energy,
particle_charges=incoming.particle_charges,
survival_probabilities=torch.zeros_like(
incoming.survival_probabilities
),
)
else:
return deepcopy(incoming)

@property
def reading(self) -> torch.Tensor:
Expand All @@ -194,7 +214,7 @@ def reading(self) -> torch.Tensor:
return self.cached_reading

read_beam = self.get_read_beam()
if read_beam is Beam.empty or read_beam is None:
if read_beam is None:
image = torch.zeros(
(int(self.effective_resolution[1]), int(self.effective_resolution[0])),
device=self.misalignment.device,
Expand Down Expand Up @@ -255,16 +275,24 @@ def reading(self) -> torch.Tensor:
)

image, _ = torch.histogramdd(
torch.stack((read_beam.x, read_beam.y)).T, bins=self.pixel_bin_edges
torch.stack((read_beam.x, read_beam.y)).T,
bins=self.pixel_bin_edges,
weight=read_beam.particle_charges
* read_beam.survival_probabilities,
)
image = torch.flipud(image.T)
elif self.method == "kde":
weights = read_beam.particle_charges * read_beam.survival_probabilities
broadcasted_x, broadcasted_y, broadcasted_weights = (
torch.broadcast_tensors(read_beam.x, read_beam.y, weights)
)
image = kde_histogram_2d(
x1=read_beam.x,
x2=read_beam.y,
x1=broadcasted_x,
x2=broadcasted_y,
bins1=self.pixel_bin_centers[0],
bins2=self.pixel_bin_centers[1],
bandwidth=self.kde_bandwidth,
weights=broadcasted_weights,
)
# Change the x, y positions
image = torch.transpose(image, -2, -1)
Expand Down
90 changes: 50 additions & 40 deletions cheetah/accelerator/space_charge_kick.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from scipy.constants import elementary_charge, epsilon_0, speed_of_light

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.particles import ParticleBeam
from cheetah.utils import verify_device_and_dtype


Expand Down Expand Up @@ -149,7 +149,8 @@ def _deposit_charge_on_grid(
)

# Accumulate the charge contributions
repeated_charges = beam.particle_charges.repeat_interleave(
survived_particle_charges = beam.particle_charges * beam.survival_probabilities
repeated_charges = survived_particle_charges.repeat_interleave(
repeats=8, dim=-1
) # Shape:(..., 8 * num_particles)
values = (cell_weights.flatten(start_dim=-2) * repeated_charges)[valid_mask]
Expand Down Expand Up @@ -545,34 +546,44 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam:
:param incoming: Beam of particles entering the element.
:returns: Beam of particles exiting the element.
"""
if incoming is Beam.empty or incoming.particles.shape[0] == 0:
return incoming
elif isinstance(incoming, ParticleBeam):
if isinstance(incoming, ParticleBeam):
# This flattening is a hack to only think about one vector dimension in the
# following code. It is reversed at the end of the function.

# Make sure that the incoming beam has at least one vector dimension
if len(incoming.particles.shape) == 2:
is_incoming_vectorized = False

vectorized_incoming = ParticleBeam(
particles=incoming.particles.unsqueeze(0),
energy=incoming.energy.unsqueeze(0),
particle_charges=incoming.particle_charges.unsqueeze(0),
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
else:
is_incoming_vectorized = True

vectorized_incoming = incoming
# Make sure that the incoming beam has at least one vector dimension by
# broadcasting with a dummy dimension (1,).
vector_shape = torch.broadcast_shapes(
incoming.particles.shape[:-2],
incoming.energy.shape,
incoming.particle_charges.shape[:-1],
incoming.survival_probabilities.shape[:-1],
(1,),
)
vectorized_incoming = ParticleBeam(
particles=torch.broadcast_to(
incoming.particles, (*vector_shape, incoming.num_particles, 7)
),
energy=torch.broadcast_to(incoming.energy, vector_shape),
particle_charges=torch.broadcast_to(
incoming.particle_charges, (*vector_shape, incoming.num_particles)
),
survival_probabilities=torch.broadcast_to(
incoming.survival_probabilities,
(*vector_shape, incoming.num_particles),
),
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)

flattened_incoming = ParticleBeam(
particles=vectorized_incoming.particles.flatten(end_dim=-3),
energy=vectorized_incoming.energy.flatten(end_dim=-1),
particle_charges=vectorized_incoming.particle_charges.flatten(
end_dim=-2
),
survival_probabilities=(
vectorized_incoming.survival_probabilities.flatten(end_dim=-2)
),
device=vectorized_incoming.particles.device,
dtype=vectorized_incoming.particles.dtype,
)
Expand Down Expand Up @@ -611,26 +622,25 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam:
..., 2
] * dt.unsqueeze(-1)

if not is_incoming_vectorized:
# Reshape to the original non-vectorised shape
outgoing = ParticleBeam.from_xyz_pxpypz(
xp_coordinates.squeeze(0),
vectorized_incoming.energy.squeeze(0),
vectorized_incoming.particle_charges.squeeze(0),
vectorized_incoming.particles.device,
vectorized_incoming.particles.dtype,
)
else:
# Reverse the flattening of the vector dimensions
outgoing = ParticleBeam.from_xyz_pxpypz(
xp_coordinates.unflatten(
dim=0, sizes=vectorized_incoming.particles.shape[:-2]
),
vectorized_incoming.energy,
vectorized_incoming.particle_charges,
vectorized_incoming.particles.device,
vectorized_incoming.particles.dtype,
)
# Reverse the flattening of the vector dimensions
outgoing_vector_shape = torch.broadcast_shapes(
incoming.particles.shape[:-2],
incoming.energy.shape,
incoming.particle_charges.shape[:-1],
incoming.survival_probabilities.shape[:-1],
self.effect_length.shape,
)
outgoing = ParticleBeam.from_xyz_pxpypz(
xp_coordinates=xp_coordinates.reshape(
(*outgoing_vector_shape, incoming.num_particles, 7)
),
energy=incoming.energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)

return outgoing
else:
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")
Expand Down
7 changes: 5 additions & 2 deletions cheetah/accelerator/transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,12 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
)

outgoing_beam = ParticleBeam(
torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1),
ref_energy,
particles=torch.stack(
(x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
Expand Down
2 changes: 0 additions & 2 deletions cheetah/particles/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ class directly, but use one of the subclasses.
:math:`\Delta E = E - E_0`
"""

empty = "I'm an empty beam!"

@classmethod
@abstractmethod
def from_parameters(
Expand Down
Loading

0 comments on commit 6cb0c99

Please sign in to comment.