Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add particle species in Beam classes and update tracking methods #276

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,28 @@ jobs:
include:
- os: macos-latest
python-version: "3.12"
- os: windows-latest
python-version: "3.12"
runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
- uses: conda-incubator/setup-miniconda@v3
with:
miniconda-version: 'latest'
conda-solver: 'classic'
python-version: ${{ matrix.python-version }}
- name: Install dependencies
activate-environment: test_environment
- name: Install Bmad and Pytao
shell: bash -el {0}
run: |
conda install -c conda-forge bmad
conda install -c conda-forge pytao
- name: Install pip dependencies
shell: bash -el {0}
run: |
python -m pip install --upgrade pip
pip install -e .
pip install -r test_requirements.txt
- name: Test with pytest
shell: bash -el {0}
run: |
pytest
30 changes: 30 additions & 0 deletions .github/workflows/pytest_windows.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: pytest_windows

on:
push:
branches: [master]
pull_request:
branches: [master]

jobs:
build:
strategy:
matrix:
os: [windows-latest]
python-version: ["3.12"]
runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
pip install -r test_requirements.txt
- name: Test with pytest
run: |
pytest --ignore=tests/bmad_benchmarks
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240, #278) (@jp-ga, @cr-xu, @jank324)
- `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe)
- Implement a converter for lattice files imported from Elegant (see #222, #251, #273, #281) (@hespe, @jank324)
- Add the option to choose the particle species `Species` for `Beam` classes (see #276) (@cr-xu)

### 🐛 Bug fixes

Expand Down
4 changes: 3 additions & 1 deletion cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def __init__(
def is_skippable(self) -> bool:
return not self.is_active

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
def transfer_map(
self, energy: torch.Tensor, particle_mass_eV: float
) -> torch.Tensor:
device = self.x_max.device
dtype = self.x_max.dtype

Expand Down
4 changes: 3 additions & 1 deletion cheetah/accelerator/bpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def __init__(self, is_active: bool = False, name: Optional[str] = None) -> None:
def is_skippable(self) -> bool:
return not self.is_active

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
def transfer_map(
self, energy: torch.Tensor, particle_mass_eV: float
) -> torch.Tensor:
return torch.eye(7, device=energy.device, dtype=energy.dtype).repeat(
(*energy.shape, 1, 1)
)
Expand Down
29 changes: 17 additions & 12 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
from matplotlib.patches import Rectangle
from scipy import constants
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
Expand All @@ -14,8 +13,6 @@

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6


class Cavity(Element):
"""
Expand All @@ -26,6 +23,9 @@ class Cavity(Element):
:param phase: Phase of the cavity in degrees.
:param frequency: Frequency of the cavity in Hz.
:param name: Unique identifier of the element.

Note: here we use the "effective" voltage, if the particle does not have unit
charge, the voltage needs to be scaled properly.
"""

def __init__(
Expand Down Expand Up @@ -75,14 +75,17 @@ def is_active(self) -> bool:
def is_skippable(self) -> bool:
return not self.is_active

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
def transfer_map(
self, energy: torch.Tensor, particle_mass_eV: float
) -> torch.Tensor:
return torch.where(
(self.voltage != 0).unsqueeze(-1).unsqueeze(-1),
self._cavity_rmatrix(energy),
self._cavity_rmatrix(energy, particle_mass_eV),
base_rmatrix(
length=self.length,
k1=torch.zeros_like(self.length),
hx=torch.zeros_like(self.length),
particle_mass_eV=particle_mass_eV,
tilt=torch.zeros_like(self.length),
energy=energy,
),
Expand Down Expand Up @@ -113,7 +116,7 @@ def _track_beam(self, incoming: Beam) -> Beam:

phi = torch.deg2rad(self.phase)

tm = self.transfer_map(incoming.energy)
tm = self.transfer_map(incoming.energy, incoming.mass_eV)
if isinstance(incoming, ParameterBeam):
outgoing_mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1)
outgoing_cov = torch.matmul(
Expand Down Expand Up @@ -161,7 +164,7 @@ def _track_beam(self, incoming: Beam) -> Beam:
- torch.cos(phi).unsqueeze(-1)
)

dgamma = self.voltage / electron_mass_eV
dgamma = self.voltage / incoming.mass_eV
if torch.any(delta_energy > 0):
T566 = (
self.length
Expand Down Expand Up @@ -246,7 +249,9 @@ def _track_beam(self, incoming: Beam) -> Beam:
)
return outgoing

def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
def _cavity_rmatrix(
self, energy: torch.Tensor, particle_mass_eV: float
) -> torch.Tensor:
"""Produces an R-matrix for a cavity when it is on, i.e. voltage > 0.0."""
device = self.length.device
dtype = self.length.dtype
Expand All @@ -255,8 +260,8 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
delta_energy = self.voltage * torch.cos(phi)
# Comment from Ocelot: Pure pi-standing-wave case
eta = torch.tensor(1.0, device=device, dtype=dtype)
Ei = energy / electron_mass_eV
Ef = (energy + delta_energy) / electron_mass_eV
Ei = energy / particle_mass_eV
Ef = (energy + delta_energy) / particle_mass_eV
Ep = (Ef - Ei) / self.length # Derivative of the energy
assert torch.all(Ei > 0), "Initial energy must be larger than 0"

Expand Down Expand Up @@ -306,14 +311,14 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
* self.length
* beta0
* self.voltage
/ electron_mass_eV
/ particle_mass_eV
* torch.sin(phi)
* (g0 * g1 * (beta0 * beta1 - 1) + 1)
/ (beta1 * g1 * (g0 - g1) ** 2)
)

r66 = Ei / Ef * beta0 / beta1
r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV)
r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * particle_mass_eV)

# Make sure that all matrix elements have the same shape
r11, r12, r21, r22, r55_cor, r56, r65, r66 = torch.broadcast_tensors(
Expand Down
16 changes: 12 additions & 4 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,20 @@ def from_merging_elements(
" incorrect tracking results."
)

device = elements[0].transfer_map(incoming_beam.energy).device
dtype = elements[0].transfer_map(incoming_beam.energy).dtype
device = (
elements[0].transfer_map(incoming_beam.energy, incoming_beam.mass_eV).device
)
dtype = (
elements[0].transfer_map(incoming_beam.energy, incoming_beam.mass_eV).dtype
)

tm = torch.eye(7, device=device, dtype=dtype).repeat(
(*incoming_beam.energy.shape, 1, 1)
)
for element in elements:
tm = torch.matmul(element.transfer_map(incoming_beam.energy), tm)
tm = torch.matmul(
element.transfer_map(incoming_beam.energy, incoming_beam.mass_eV), tm
)
incoming_beam = element.track(incoming_beam)

combined_length = sum(element.length for element in elements)
Expand All @@ -82,7 +88,9 @@ def from_merging_elements(
tm, length=combined_length, device=device, dtype=dtype, name=combined_name
)

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
def transfer_map(
self, energy: torch.Tensor, particle_mass_eV: float
) -> torch.Tensor:
return self._transfer_map

@property
Expand Down
22 changes: 9 additions & 13 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
Expand All @@ -14,8 +13,6 @@

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6


class Dipole(Element):
"""
Expand Down Expand Up @@ -196,10 +193,9 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
py = incoming.py
tau = incoming.tau
delta = incoming.p
mc2 = incoming.mass_eV

z, pz, p0c = bmadx.cheetah_to_bmad_z_pz(
tau, delta, incoming.energy, electron_mass_eV
)
z, pz, p0c = bmadx.cheetah_to_bmad_z_pz(tau, delta, incoming.energy, mc2)

# Begin Bmad-X tracking
x, px, y, py = bmadx.offset_particle_set(
Expand All @@ -208,9 +204,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:

if self.fringe_at == "entrance" or self.fringe_at == "both":
px, py = self._bmadx_fringe_linear("entrance", x, px, y, py)
x, px, y, py, z, pz = self._bmadx_body(
x, px, y, py, z, pz, p0c, electron_mass_eV
)
x, px, y, py, z, pz = self._bmadx_body(x, px, y, py, z, pz, p0c, mc2)
if self.fringe_at == "exit" or self.fringe_at == "both":
px, py = self._bmadx_fringe_linear("exit", x, px, y, py)

Expand All @@ -220,9 +214,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
# End of Bmad-X tracking

# Convert back to Cheetah coordinates
tau, delta, ref_energy = bmadx.bmad_to_cheetah_z_pz(
z, pz, p0c, electron_mass_eV
)
tau, delta, ref_energy = bmadx.bmad_to_cheetah_z_pz(z, pz, p0c, mc2)

# Broadcast to align their shapes so that they can be stacked
x, px, y, py, tau, delta = torch.broadcast_tensors(x, px, y, py, tau, delta)
Expand All @@ -235,6 +227,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
particle_charges=incoming.particle_charges,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
species=incoming.species,
)
return outgoing_beam

Expand Down Expand Up @@ -368,7 +361,9 @@ def _bmadx_fringe_linear(

return px_f, py_f

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
def transfer_map(
self, energy: torch.Tensor, particle_mass_eV: float
) -> torch.Tensor:
device = self.length.device
dtype = self.length.dtype

Expand All @@ -380,6 +375,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
length=self.length,
k1=self.k1,
hx=self.hx,
particle_mass_eV=particle_mass_eV,
tilt=torch.zeros_like(self.length),
energy=energy,
) # Tilt is applied after adding edges
Expand Down
16 changes: 8 additions & 8 deletions cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import matplotlib.pyplot as plt
import torch
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
Expand All @@ -11,8 +10,6 @@

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

electron_mass_eV = physical_constants["electron mass energy equivalent in MeV"][0] * 1e6


class Drift(Element):
"""
Expand Down Expand Up @@ -40,11 +37,13 @@ def __init__(
self.register_buffer("length", torch.as_tensor(length, **factory_kwargs))
self.tracking_method = tracking_method

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
def transfer_map(
self, energy: torch.Tensor, particle_mass_eV: float
) -> torch.Tensor:
device = self.length.device
dtype = self.length.dtype

_, igamma2, beta = compute_relativistic_factors(energy)
_, igamma2, beta = compute_relativistic_factors(energy, particle_mass_eV)

vector_shape = torch.broadcast_shapes(self.length.shape, igamma2.shape)

Expand Down Expand Up @@ -92,18 +91,18 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
delta = incoming.p

z, pz, p0c = bmadx.cheetah_to_bmad_z_pz(
tau, delta, incoming.energy, electron_mass_eV
tau, delta, incoming.energy, incoming.mass_eV
)

# Begin Bmad-X tracking
x, y, z = bmadx.track_a_drift(
self.length, x, px, y, py, z, pz, p0c, electron_mass_eV
self.length, x, px, y, py, z, pz, p0c, incoming.mass_eV
)
# End of Bmad-X tracking

# Convert back to Cheetah coordinates
tau, delta, ref_energy = bmadx.bmad_to_cheetah_z_pz(
z, pz, p0c, electron_mass_eV
z, pz, p0c, incoming.mass_eV
)

# Broadcast to align their shapes so that they can be stacked
Expand All @@ -117,6 +116,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
particle_charges=incoming.particle_charges,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
species=incoming.species,
)
return outgoing_beam

Expand Down
Loading