Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Aperture vectorisation issue #268

Merged
merged 34 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
dada294
Add test to detect `Aperture` vectorisation issue
jank324 Oct 3, 2024
e5cc55d
Minor fixes to the vectorised aperture test itself
jank324 Oct 3, 2024
d9691cc
Extend test to cover both aperture shapes
jank324 Oct 3, 2024
b175b78
Change logical accumulators to torch ones
jank324 Oct 3, 2024
09f7b1b
Merge branch 'master' into 241-vectorised-aperture-tracking-is-broken
cr-xu Oct 24, 2024
3f5d48f
Add `particle_survival` probability to the `ParticleBeam`
cr-xu Oct 24, 2024
247b64c
Include particle survival into the statistical beam parameters calcul…
cr-xu Oct 24, 2024
5a79ed3
Add tests for variance calculation
cr-xu Oct 25, 2024
0d7d07d
Account for the lost particles in space charge effects
cr-xu Oct 25, 2024
c8a1a02
Update `CHANGELOG.md`
cr-xu Oct 25, 2024
0cd8eda
Update CHANGELOG.md
cr-xu Oct 25, 2024
d89deb8
Merge branch 'master' into 241-vectorised-aperture-tracking-is-broken
jank324 Nov 20, 2024
31b320d
Merge branch 'master' into 241-vectorised-aperture-tracking-is-broken
jank324 Nov 20, 2024
1b5ce33
Merge branch 'master' into 241-vectorised-aperture-tracking-is-broken
jank324 Nov 22, 2024
07f4f43
Merge branch 'master' into 241-vectorised-aperture-tracking-is-broken
jank324 Nov 25, 2024
157f6dd
More descriptive property name
jank324 Nov 28, 2024
0172baa
Align the way survival probabilities are instatiated with other prope…
jank324 Nov 28, 2024
e729a06
More consistent argument order
jank324 Nov 28, 2024
da13b30
Fix `black` warning because of too long line
jank324 Nov 28, 2024
1caf5ff
Clean up aperture test
jank324 Nov 28, 2024
73657cd
Fix `Aperture` test and clean up `Aperture` code
jank324 Nov 28, 2024
fa3f0a6
Some code cleanup in `ParameterBeam`
jank324 Nov 28, 2024
52116ab
Remove `Beam.empty`
jank324 Nov 28, 2024
bd6515f
Fix `flake8` warning about unused import
jank324 Nov 28, 2024
db6db83
Minor code readiblity improvement
jank324 Nov 28, 2024
7a538d3
Some more readibility imporvements to the code
jank324 Nov 28, 2024
9e1189f
Clean up statistics tests
jank324 Nov 28, 2024
f1ca213
Refactor assertions for readability in test_compare_ocelot.py
jank324 Nov 28, 2024
2e44fb5
Properly fix Bmad-X test beam
jank324 Nov 28, 2024
3f18d18
Rename to surval probabilities properties
jank324 Nov 28, 2024
fdfe717
Minor test code cleanup
jank324 Nov 28, 2024
d791fc6
Adapt changelog message
jank324 Nov 28, 2024
d11da6f
Further adapt changelog message
jank324 Nov 28, 2024
8aa27ab
Add screen reading weighting of macroparticles by charge and survival
jank324 Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265, #284) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- As part of the vectorised rewrite, the `Aperture` no longer removes particles. Instead, `ParticleBeam.survival_probabilities` tracks the probability that a particle has survived (i.e. the inverse probability that it has been lost). This also comes with the removal of `Beam.empty`. Note that particle losses in `Aperture` are currently not differentiable. This will be addressed in a future release. (see #268) (@cr-xu, @jank324)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)
- The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324)
Expand Down
48 changes: 23 additions & 25 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ class Aperture(Element):
"""
Physical aperture.

:param x_max: half size horizontal offset in [m]
:param y_max: half size vertical offset in [m]
NOTE: The aperture currently only affects beams of type `ParticleBeam` and only has
an effect when the aperture is active.

:param x_max: half size horizontal offset in [m].
:param y_max: half size vertical offset in [m].
:param shape: Shape of the aperture. Can be "rectangular" or "elliptical".
:param is_active: If the aperture actually blocks particles.
:param name: Unique identifier of the element.
Expand Down Expand Up @@ -72,41 +75,36 @@ def track(self, incoming: Beam) -> Beam:
if not (isinstance(incoming, ParticleBeam) and self.is_active):
return incoming

assert self.x_max >= 0 and self.y_max >= 0
assert torch.all(self.x_max >= 0) and torch.all(self.y_max >= 0)
assert self.shape in [
"rectangular",
"elliptical",
], f"Unknown aperture shape {self.shape}"

if self.shape == "rectangular":
survived_mask = torch.logical_and(
torch.logical_and(incoming.x > -self.x_max, incoming.x < self.x_max),
torch.logical_and(incoming.y > -self.y_max, incoming.y < self.y_max),
torch.logical_and(
incoming.x > -self.x_max.unsqueeze(-1),
incoming.x < self.x_max.unsqueeze(-1),
),
torch.logical_and(
incoming.y > -self.y_max.unsqueeze(-1),
incoming.y < self.y_max.unsqueeze(-1),
),
)
elif self.shape == "elliptical":
survived_mask = (
incoming.x**2 / self.x_max**2 + incoming.y**2 / self.y_max**2
incoming.x**2 / self.x_max.unsqueeze(-1) ** 2
+ incoming.y**2 / self.y_max.unsqueeze(-1) ** 2
) <= 1.0
outgoing_particles = incoming.particles[survived_mask]

outgoing_particle_charges = incoming.particle_charges[survived_mask]

self.lost_particles = incoming.particles[torch.logical_not(survived_mask)]

self.lost_particle_charges = incoming.particle_charges[
torch.logical_not(survived_mask)
]

return (
ParticleBeam(
outgoing_particles,
incoming.energy,
particle_charges=outgoing_particle_charges,
device=outgoing_particles.device,
dtype=outgoing_particles.dtype,
)
if outgoing_particles.shape[0] > 0
else ParticleBeam.empty
return ParticleBeam(
particles=incoming.particles,
energy=incoming.energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities * survived_mask,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)

def split(self, resolution: torch.Tensor) -> list[Element]:
Expand Down
4 changes: 1 addition & 3 deletions cheetah/accelerator/bpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
)

def track(self, incoming: Beam) -> Beam:
if incoming is Beam.empty:
self.reading = None
elif isinstance(incoming, ParameterBeam):
if isinstance(incoming, ParameterBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
elif isinstance(incoming, ParticleBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
Expand Down
9 changes: 4 additions & 5 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ def track(self, incoming: Beam) -> Beam:
:param incoming: Beam of particles entering the element.
:return: Beam of particles exiting the element.
"""
if incoming is Beam.empty:
return incoming
elif isinstance(incoming, (ParameterBeam, ParticleBeam)):
if isinstance(incoming, (ParameterBeam, ParticleBeam)):
return self._track_beam(incoming)
else:
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")
Expand Down Expand Up @@ -244,9 +242,10 @@ def _track_beam(self, incoming: Beam) -> Beam:
return outgoing
else: # ParticleBeam
outgoing = ParticleBeam(
outgoing_particles,
outgoing_energy,
particles=outgoing_particles,
energy=outgoing_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=outgoing_particles.device,
dtype=outgoing_particles.dtype,
)
Expand Down
1 change: 1 addition & 0 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
Expand Down
1 change: 1 addition & 0 deletions cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
Expand Down
5 changes: 2 additions & 3 deletions cheetah/accelerator/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def track(self, incoming: Beam) -> Beam:
:param incoming: Beam of particles entering the element.
:return: Beam of particles exiting the element.
"""
if incoming is Beam.empty:
return incoming
elif isinstance(incoming, ParameterBeam):
if isinstance(incoming, ParameterBeam):
tm = self.transfer_map(incoming.energy)
mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1)
cov = torch.matmul(tm, torch.matmul(incoming._cov, tm.transpose(-2, -1)))
Expand All @@ -81,6 +79,7 @@ def track(self, incoming: Beam) -> Beam:
new_particles,
incoming.energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=new_particles.device,
dtype=new_particles.dtype,
)
Expand Down
7 changes: 5 additions & 2 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,12 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
)

outgoing_beam = ParticleBeam(
torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1),
ref_energy,
particles=torch.stack(
(x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
Expand Down
38 changes: 33 additions & 5 deletions cheetah/accelerator/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1))

def track(self, incoming: Beam) -> Beam:
# Record the beam only when the screen is active
if self.is_active:
copy_of_incoming = deepcopy(incoming)

Expand Down Expand Up @@ -185,7 +186,26 @@ def track(self, incoming: Beam) -> Beam:

self.set_read_beam(copy_of_incoming)

return Beam.empty if self.is_blocking else incoming
# Block the beam only when the screen is active and blocking
if self.is_active and self.is_blocking:
if isinstance(incoming, ParameterBeam):
return ParameterBeam(
mu=incoming._mu,
cov=incoming._cov,
energy=incoming.energy,
total_charge=torch.zeros_like(incoming.total_charge),
)
elif isinstance(incoming, ParticleBeam):
return ParticleBeam(
particles=incoming.particles,
energy=incoming.energy,
particle_charges=incoming.particle_charges,
survival_probabilities=torch.zeros_like(
incoming.survival_probabilities
),
)
else:
return deepcopy(incoming)

@property
def reading(self) -> torch.Tensor:
Expand All @@ -194,7 +214,7 @@ def reading(self) -> torch.Tensor:
return self.cached_reading

read_beam = self.get_read_beam()
if read_beam is Beam.empty or read_beam is None:
if read_beam is None:
image = torch.zeros(
(int(self.effective_resolution[1]), int(self.effective_resolution[0])),
device=self.misalignment.device,
Expand Down Expand Up @@ -255,16 +275,24 @@ def reading(self) -> torch.Tensor:
)

image, _ = torch.histogramdd(
torch.stack((read_beam.x, read_beam.y)).T, bins=self.pixel_bin_edges
torch.stack((read_beam.x, read_beam.y)).T,
bins=self.pixel_bin_edges,
weight=read_beam.particle_charges
* read_beam.survival_probabilities,
)
image = torch.flipud(image.T)
elif self.method == "kde":
weights = read_beam.particle_charges * read_beam.survival_probabilities
broadcasted_x, broadcasted_y, broadcasted_weights = (
torch.broadcast_tensors(read_beam.x, read_beam.y, weights)
)
image = kde_histogram_2d(
x1=read_beam.x,
x2=read_beam.y,
x1=broadcasted_x,
x2=broadcasted_y,
bins1=self.pixel_bin_centers[0],
bins2=self.pixel_bin_centers[1],
bandwidth=self.kde_bandwidth,
weights=broadcasted_weights,
)
# Change the x, y positions
image = torch.transpose(image, -2, -1)
Expand Down
90 changes: 50 additions & 40 deletions cheetah/accelerator/space_charge_kick.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from scipy.constants import elementary_charge, epsilon_0, speed_of_light

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.particles import ParticleBeam
from cheetah.utils import verify_device_and_dtype


Expand Down Expand Up @@ -149,7 +149,8 @@ def _deposit_charge_on_grid(
)

# Accumulate the charge contributions
repeated_charges = beam.particle_charges.repeat_interleave(
survived_particle_charges = beam.particle_charges * beam.survival_probabilities
repeated_charges = survived_particle_charges.repeat_interleave(
repeats=8, dim=-1
) # Shape:(..., 8 * num_particles)
values = (cell_weights.flatten(start_dim=-2) * repeated_charges)[valid_mask]
Expand Down Expand Up @@ -545,34 +546,44 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam:
:param incoming: Beam of particles entering the element.
:returns: Beam of particles exiting the element.
"""
if incoming is Beam.empty or incoming.particles.shape[0] == 0:
return incoming
elif isinstance(incoming, ParticleBeam):
if isinstance(incoming, ParticleBeam):
# This flattening is a hack to only think about one vector dimension in the
# following code. It is reversed at the end of the function.

# Make sure that the incoming beam has at least one vector dimension
if len(incoming.particles.shape) == 2:
is_incoming_vectorized = False

vectorized_incoming = ParticleBeam(
particles=incoming.particles.unsqueeze(0),
energy=incoming.energy.unsqueeze(0),
particle_charges=incoming.particle_charges.unsqueeze(0),
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
else:
is_incoming_vectorized = True

vectorized_incoming = incoming
# Make sure that the incoming beam has at least one vector dimension by
# broadcasting with a dummy dimension (1,).
vector_shape = torch.broadcast_shapes(
incoming.particles.shape[:-2],
incoming.energy.shape,
incoming.particle_charges.shape[:-1],
incoming.survival_probabilities.shape[:-1],
(1,),
)
vectorized_incoming = ParticleBeam(
particles=torch.broadcast_to(
incoming.particles, (*vector_shape, incoming.num_particles, 7)
),
energy=torch.broadcast_to(incoming.energy, vector_shape),
particle_charges=torch.broadcast_to(
incoming.particle_charges, (*vector_shape, incoming.num_particles)
),
survival_probabilities=torch.broadcast_to(
incoming.survival_probabilities,
(*vector_shape, incoming.num_particles),
),
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)

flattened_incoming = ParticleBeam(
particles=vectorized_incoming.particles.flatten(end_dim=-3),
energy=vectorized_incoming.energy.flatten(end_dim=-1),
particle_charges=vectorized_incoming.particle_charges.flatten(
end_dim=-2
),
survival_probabilities=(
vectorized_incoming.survival_probabilities.flatten(end_dim=-2)
),
device=vectorized_incoming.particles.device,
dtype=vectorized_incoming.particles.dtype,
)
Expand Down Expand Up @@ -611,26 +622,25 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam:
..., 2
] * dt.unsqueeze(-1)

if not is_incoming_vectorized:
# Reshape to the original non-vectorised shape
outgoing = ParticleBeam.from_xyz_pxpypz(
xp_coordinates.squeeze(0),
vectorized_incoming.energy.squeeze(0),
vectorized_incoming.particle_charges.squeeze(0),
vectorized_incoming.particles.device,
vectorized_incoming.particles.dtype,
)
else:
# Reverse the flattening of the vector dimensions
outgoing = ParticleBeam.from_xyz_pxpypz(
xp_coordinates.unflatten(
dim=0, sizes=vectorized_incoming.particles.shape[:-2]
),
vectorized_incoming.energy,
vectorized_incoming.particle_charges,
vectorized_incoming.particles.device,
vectorized_incoming.particles.dtype,
)
# Reverse the flattening of the vector dimensions
outgoing_vector_shape = torch.broadcast_shapes(
incoming.particles.shape[:-2],
incoming.energy.shape,
incoming.particle_charges.shape[:-1],
incoming.survival_probabilities.shape[:-1],
self.effect_length.shape,
)
outgoing = ParticleBeam.from_xyz_pxpypz(
xp_coordinates=xp_coordinates.reshape(
(*outgoing_vector_shape, incoming.num_particles, 7)
),
energy=incoming.energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)

return outgoing
else:
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")
Expand Down
7 changes: 5 additions & 2 deletions cheetah/accelerator/transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,12 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
)

outgoing_beam = ParticleBeam(
torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1),
ref_energy,
particles=torch.stack(
(x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
Expand Down
2 changes: 0 additions & 2 deletions cheetah/particles/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ class directly, but use one of the subclasses.
:math:`\Delta E = E - E_0`
"""

empty = "I'm an empty beam!"

@classmethod
@abstractmethod
def from_parameters(
Expand Down
Loading