Skip to content

Commit

Permalink
Merge pull request #254 from Hespe/253-default-dtype
Browse files Browse the repository at this point in the history
Improve default dtype selection
  • Loading branch information
jank324 authored Nov 25, 2024
2 parents 0f02af6 + b086391 commit 63b9623
Show file tree
Hide file tree
Showing 29 changed files with 788 additions and 351 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258, #265, #284) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)
- The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324)

### 🚀 Features

Expand Down
12 changes: 6 additions & 6 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.utils import UniqueNameGenerator
from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -25,14 +24,15 @@ class Aperture(Element):

def __init__(
self,
x_max: Optional[Union[torch.Tensor, nn.Parameter]] = None,
y_max: Optional[Union[torch.Tensor, nn.Parameter]] = None,
x_max: Optional[torch.Tensor] = None,
y_max: Optional[torch.Tensor] = None,
shape: Literal["rectangular", "elliptical"] = "rectangular",
is_active: bool = True,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([x_max, y_max], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
39 changes: 22 additions & 17 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from scipy import constants
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.track_methods import base_rmatrix
from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors
from cheetah.utils import (
UniqueNameGenerator,
compute_relativistic_factors,
verify_device_and_dtype,
)

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -30,14 +33,17 @@ class Cavity(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
voltage: Optional[Union[torch.Tensor, nn.Parameter]] = None,
phase: Optional[Union[torch.Tensor, nn.Parameter]] = None,
frequency: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor,
voltage: Optional[torch.Tensor] = None,
phase: Optional[torch.Tensor] = None,
frequency: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length, voltage, phase, frequency], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down Expand Up @@ -248,13 +254,12 @@ def _track_beam(self, incoming: Beam) -> Beam:

def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
"""Produces an R-matrix for a cavity when it is on, i.e. voltage > 0.0."""
device = self.length.device
dtype = self.length.dtype
factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype}

phi = torch.deg2rad(self.phase)
delta_energy = self.voltage * torch.cos(phi)
# Comment from Ocelot: Pure pi-standing-wave case
eta = torch.tensor(1.0, device=device, dtype=dtype)
eta = torch.tensor(1.0, **factory_kwargs)
Ei = energy / electron_mass_eV
Ef = (energy + delta_energy) / electron_mass_eV
Ep = (Ef - Ei) / self.length # Derivative of the energy
Expand Down Expand Up @@ -288,12 +293,12 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
)
)

r56 = torch.tensor(0.0)
beta0 = torch.tensor(1.0)
beta1 = torch.tensor(1.0)
r56 = torch.tensor(0.0, **factory_kwargs)
beta0 = torch.tensor(1.0, **factory_kwargs)
beta1 = torch.tensor(1.0, **factory_kwargs)

k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light)
r55_cor = torch.tensor(0.0)
k = 2 * torch.pi * self.frequency / constants.speed_of_light
r55_cor = torch.tensor(0.0, **factory_kwargs)
if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if?
beta0 = torch.sqrt(1 - 1 / Ei**2)
beta1 = torch.sqrt(1 - 1 / Ef**2)
Expand All @@ -320,7 +325,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
r11, r12, r21, r22, r55_cor, r56, r65, r66
)

R = torch.eye(7, device=device, dtype=dtype).repeat((*r11.shape, 1, 1))
R = torch.eye(7, **factory_kwargs).repeat((*r11.shape, 1, 1))
R[..., 0, 0] = r11
R[..., 0, 1] = r12
R[..., 1, 0] = r21
Expand Down
10 changes: 5 additions & 5 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam
from cheetah.utils import UniqueNameGenerator
from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -19,12 +18,13 @@ class CustomTransferMap(Element):

def __init__(
self,
transfer_map: Union[torch.Tensor, nn.Parameter],
transfer_map: torch.Tensor,
length: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([transfer_map, length], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
85 changes: 56 additions & 29 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.track_methods import base_rmatrix, rotation_matrix
from cheetah.utils import UniqueNameGenerator, bmadx
from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand Down Expand Up @@ -47,23 +46,39 @@ class Dipole(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
angle: Optional[Union[torch.Tensor, nn.Parameter]] = None,
k1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
e1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
e2: Optional[Union[torch.Tensor, nn.Parameter]] = None,
tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None,
gap: Optional[Union[torch.Tensor, nn.Parameter]] = None,
gap_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None,
fringe_integral: Optional[Union[torch.Tensor, nn.Parameter]] = None,
fringe_integral_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor,
angle: Optional[torch.Tensor] = None,
k1: Optional[torch.Tensor] = None,
e1: Optional[torch.Tensor] = None,
e2: Optional[torch.Tensor] = None,
tilt: Optional[torch.Tensor] = None,
gap: Optional[torch.Tensor] = None,
gap_exit: Optional[torch.Tensor] = None,
fringe_integral: Optional[torch.Tensor] = None,
fringe_integral_exit: Optional[torch.Tensor] = None,
fringe_at: Literal["neither", "entrance", "exit", "both"] = "both",
fringe_type: Literal["linear_edge"] = "linear_edge",
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
):
device, dtype = verify_device_and_dtype(
[
length,
angle,
k1,
e1,
e2,
tilt,
gap,
gap_exit,
fringe_integral,
fringe_integral_exit,
],
device,
dtype,
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down Expand Up @@ -203,7 +218,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:

# Begin Bmad-X tracking
x, px, y, py = bmadx.offset_particle_set(
torch.tensor(0.0), torch.tensor(0.0), self.tilt, x, px, y, py
torch.zeros_like(self.tilt),
torch.zeros_like(self.tilt),
self.tilt,
x,
px,
y,
py,
)

if self.fringe_at == "entrance" or self.fringe_at == "both":
Expand All @@ -215,7 +236,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
px, py = self._bmadx_fringe_linear("exit", x, px, y, py)

x, px, y, py = bmadx.offset_particle_unset(
torch.tensor(0.0), torch.tensor(0.0), self.tilt, x, px, y, py
torch.zeros_like(self.tilt),
torch.zeros_like(self.tilt),
self.tilt,
x,
px,
y,
py,
)
# End of Bmad-X tracking

Expand All @@ -240,15 +267,15 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:

def _bmadx_body(
self,
x: Union[torch.Tensor, nn.Parameter],
px: Union[torch.Tensor, nn.Parameter],
y: Union[torch.Tensor, nn.Parameter],
py: Union[torch.Tensor, nn.Parameter],
z: Union[torch.Tensor, nn.Parameter],
pz: Union[torch.Tensor, nn.Parameter],
p0c: Union[torch.Tensor, nn.Parameter],
x: torch.Tensor,
px: torch.Tensor,
y: torch.Tensor,
py: torch.Tensor,
z: torch.Tensor,
pz: torch.Tensor,
p0c: torch.Tensor,
mc2: float,
) -> list[Union[torch.Tensor, nn.Parameter]]:
) -> list[torch.Tensor]:
"""
Track particle coordinates through bend body.
Expand Down Expand Up @@ -335,11 +362,11 @@ def _bmadx_body(
def _bmadx_fringe_linear(
self,
location: Literal["entrance", "exit"],
x: Union[torch.Tensor, nn.Parameter],
px: Union[torch.Tensor, nn.Parameter],
y: Union[torch.Tensor, nn.Parameter],
py: Union[torch.Tensor, nn.Parameter],
) -> list[Union[torch.Tensor, nn.Parameter]]:
x: torch.Tensor,
px: torch.Tensor,
y: torch.Tensor,
py: torch.Tensor,
) -> list[torch.Tensor]:
"""
Tracks linear fringe.
Expand Down
7 changes: 3 additions & 4 deletions cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import torch
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
Expand All @@ -28,11 +27,11 @@ class Drift(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
length: torch.Tensor,
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
Expand Down
16 changes: 10 additions & 6 deletions cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors
from cheetah.utils import (
UniqueNameGenerator,
compute_relativistic_factors,
verify_device_and_dtype,
)

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -25,12 +28,13 @@ class HorizontalCorrector(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
angle: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor,
angle: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([length, angle], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
Loading

0 comments on commit 63b9623

Please sign in to comment.