Skip to content

Commit

Permalink
Merge branch 'master' into get-item-method
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 authored Nov 28, 2024
2 parents 6250845 + 6cb0c99 commit 55ae402
Show file tree
Hide file tree
Showing 38 changed files with 1,269 additions and 479 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265, #284) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- As part of the vectorised rewrite, the `Aperture` no longer removes particles. Instead, `ParticleBeam.survival_probabilities` tracks the probability that a particle has survived (i.e. the inverse probability that it has been lost). This also comes with the removal of `Beam.empty`. Note that particle losses in `Aperture` are currently not differentiable. This will be addressed in a future release. (see #268) (@cr-xu, @jank324)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)
- The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324)

### 🚀 Features

Expand Down
60 changes: 29 additions & 31 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.utils import UniqueNameGenerator
from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -16,23 +15,27 @@ class Aperture(Element):
"""
Physical aperture.
:param x_max: half size horizontal offset in [m]
:param y_max: half size vertical offset in [m]
NOTE: The aperture currently only affects beams of type `ParticleBeam` and only has
an effect when the aperture is active.
:param x_max: half size horizontal offset in [m].
:param y_max: half size vertical offset in [m].
:param shape: Shape of the aperture. Can be "rectangular" or "elliptical".
:param is_active: If the aperture actually blocks particles.
:param name: Unique identifier of the element.
"""

def __init__(
self,
x_max: Optional[Union[torch.Tensor, nn.Parameter]] = None,
y_max: Optional[Union[torch.Tensor, nn.Parameter]] = None,
x_max: Optional[torch.Tensor] = None,
y_max: Optional[torch.Tensor] = None,
shape: Literal["rectangular", "elliptical"] = "rectangular",
is_active: bool = True,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([x_max, y_max], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down Expand Up @@ -72,41 +75,36 @@ def track(self, incoming: Beam) -> Beam:
if not (isinstance(incoming, ParticleBeam) and self.is_active):
return incoming

assert self.x_max >= 0 and self.y_max >= 0
assert torch.all(self.x_max >= 0) and torch.all(self.y_max >= 0)
assert self.shape in [
"rectangular",
"elliptical",
], f"Unknown aperture shape {self.shape}"

if self.shape == "rectangular":
survived_mask = torch.logical_and(
torch.logical_and(incoming.x > -self.x_max, incoming.x < self.x_max),
torch.logical_and(incoming.y > -self.y_max, incoming.y < self.y_max),
torch.logical_and(
incoming.x > -self.x_max.unsqueeze(-1),
incoming.x < self.x_max.unsqueeze(-1),
),
torch.logical_and(
incoming.y > -self.y_max.unsqueeze(-1),
incoming.y < self.y_max.unsqueeze(-1),
),
)
elif self.shape == "elliptical":
survived_mask = (
incoming.x**2 / self.x_max**2 + incoming.y**2 / self.y_max**2
incoming.x**2 / self.x_max.unsqueeze(-1) ** 2
+ incoming.y**2 / self.y_max.unsqueeze(-1) ** 2
) <= 1.0
outgoing_particles = incoming.particles[survived_mask]

outgoing_particle_charges = incoming.particle_charges[survived_mask]

self.lost_particles = incoming.particles[torch.logical_not(survived_mask)]

self.lost_particle_charges = incoming.particle_charges[
torch.logical_not(survived_mask)
]

return (
ParticleBeam(
outgoing_particles,
incoming.energy,
particle_charges=outgoing_particle_charges,
device=outgoing_particles.device,
dtype=outgoing_particles.dtype,
)
if outgoing_particles.shape[0] > 0
else ParticleBeam.empty
return ParticleBeam(
particles=incoming.particles,
energy=incoming.energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities * survived_mask,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)

def split(self, resolution: torch.Tensor) -> list[Element]:
Expand Down
4 changes: 1 addition & 3 deletions cheetah/accelerator/bpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
)

def track(self, incoming: Beam) -> Beam:
if incoming is Beam.empty:
self.reading = None
elif isinstance(incoming, ParameterBeam):
if isinstance(incoming, ParameterBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
elif isinstance(incoming, ParticleBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
Expand Down
48 changes: 26 additions & 22 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from scipy import constants
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.track_methods import base_rmatrix
from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors
from cheetah.utils import (
UniqueNameGenerator,
compute_relativistic_factors,
verify_device_and_dtype,
)

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -30,14 +33,17 @@ class Cavity(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
voltage: Optional[Union[torch.Tensor, nn.Parameter]] = None,
phase: Optional[Union[torch.Tensor, nn.Parameter]] = None,
frequency: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor,
voltage: Optional[torch.Tensor] = None,
phase: Optional[torch.Tensor] = None,
frequency: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length, voltage, phase, frequency], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down Expand Up @@ -97,9 +103,7 @@ def track(self, incoming: Beam) -> Beam:
:param incoming: Beam of particles entering the element.
:return: Beam of particles exiting the element.
"""
if incoming is Beam.empty:
return incoming
elif isinstance(incoming, (ParameterBeam, ParticleBeam)):
if isinstance(incoming, (ParameterBeam, ParticleBeam)):
return self._track_beam(incoming)
else:
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")
Expand Down Expand Up @@ -238,23 +242,23 @@ def _track_beam(self, incoming: Beam) -> Beam:
return outgoing
else: # ParticleBeam
outgoing = ParticleBeam(
outgoing_particles,
outgoing_energy,
particles=outgoing_particles,
energy=outgoing_energy,
particle_charges=incoming.particle_charges,
survival_probabilities=incoming.survival_probabilities,
device=outgoing_particles.device,
dtype=outgoing_particles.dtype,
)
return outgoing

def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
"""Produces an R-matrix for a cavity when it is on, i.e. voltage > 0.0."""
device = self.length.device
dtype = self.length.dtype
factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype}

phi = torch.deg2rad(self.phase)
delta_energy = self.voltage * torch.cos(phi)
# Comment from Ocelot: Pure pi-standing-wave case
eta = torch.tensor(1.0, device=device, dtype=dtype)
eta = torch.tensor(1.0, **factory_kwargs)
Ei = energy / electron_mass_eV
Ef = (energy + delta_energy) / electron_mass_eV
Ep = (Ef - Ei) / self.length # Derivative of the energy
Expand Down Expand Up @@ -288,12 +292,12 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
)
)

r56 = torch.tensor(0.0)
beta0 = torch.tensor(1.0)
beta1 = torch.tensor(1.0)
r56 = torch.tensor(0.0, **factory_kwargs)
beta0 = torch.tensor(1.0, **factory_kwargs)
beta1 = torch.tensor(1.0, **factory_kwargs)

k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light)
r55_cor = torch.tensor(0.0)
k = 2 * torch.pi * self.frequency / constants.speed_of_light
r55_cor = torch.tensor(0.0, **factory_kwargs)
if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if?
beta0 = torch.sqrt(1 - 1 / Ei**2)
beta1 = torch.sqrt(1 - 1 / Ef**2)
Expand All @@ -320,7 +324,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
r11, r12, r21, r22, r55_cor, r56, r65, r66
)

R = torch.eye(7, device=device, dtype=dtype).repeat((*r11.shape, 1, 1))
R = torch.eye(7, **factory_kwargs).repeat((*r11.shape, 1, 1))
R[..., 0, 0] = r11
R[..., 0, 1] = r12
R[..., 1, 0] = r21
Expand Down
10 changes: 5 additions & 5 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam
from cheetah.utils import UniqueNameGenerator
from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -19,12 +18,13 @@ class CustomTransferMap(Element):

def __init__(
self,
transfer_map: Union[torch.Tensor, nn.Parameter],
transfer_map: torch.Tensor,
length: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([transfer_map, length], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
Loading

0 comments on commit 55ae402

Please sign in to comment.