Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add traveling-wave cavity model #286

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
81 changes: 58 additions & 23 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Cavity(Element):
:param phase: Phase of the cavity in degrees.
:param frequency: Frequency of the cavity in Hz.
:param name: Unique identifier of the element.
:param cavity_type: Type of the cavity.
"""

def __init__(
Expand All @@ -37,6 +38,7 @@ def __init__(
voltage: Optional[torch.Tensor] = None,
phase: Optional[torch.Tensor] = None,
frequency: Optional[torch.Tensor] = None,
cavity_type: Optional[str] = "standing_wave",
name: Optional[str] = None,
device=None,
dtype=None,
Expand All @@ -58,6 +60,8 @@ def __init__(
self.phase = torch.as_tensor(phase, **factory_kwargs)
if frequency is not None:
self.frequency = torch.as_tensor(frequency, **factory_kwargs)

self.cavity_type = cavity_type

@property
def is_active(self) -> bool:
Expand Down Expand Up @@ -248,35 +252,65 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
Ei = energy / electron_mass_eV
Ef = (energy + delta_energy) / electron_mass_eV
Ep = (Ef - Ei) / self.length # Derivative of the energy
dE = Ef - Ei
assert torch.all(Ei > 0), "Initial energy must be larger than 0"

alpha = torch.sqrt(eta / 8) / torch.cos(phi) * torch.log(Ef / Ei)

r11 = torch.cos(alpha) - torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha)

# In Ocelot r12 is defined as below only if abs(Ep) > 10, and self.length
# otherwise. This is implemented differently here in order to achieve results
# closer to Bmad.
r12 = torch.sqrt(8 / eta) * Ei / Ep * torch.cos(phi) * torch.sin(alpha)

r21 = (
-Ep
/ Ef
* (
torch.cos(phi) / torch.sqrt(2 * eta)
+ torch.sqrt(eta / 8) / torch.cos(phi)
if self.cavity_type == 'standing_wave':
r11 = torch.cos(alpha)
- torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha)

# In Ocelot r12 is defined as below only if abs(Ep) > 10, and self.length
# otherwise. This is implemented differently here to achieve results
# closer to Bmad.
r12 = torch.sqrt(8 / eta) * Ei / Ep * torch.cos(phi) * torch.sin(alpha)

r21 = (
-Ep
/ Ef
* (
torch.cos(phi) / torch.sqrt(2 * eta)
+ torch.sqrt(eta / 8) / torch.cos(phi)
)
* torch.sin(alpha)
)
* torch.sin(alpha)
)

r22 = (
Ei
/ Ef
* (
torch.cos(alpha)
+ torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha)
r22 = (
Ei
/ Ef
* (
torch.cos(alpha)
+ torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha)
)
)
)

if self.cavity_type == 'traveling_wave':
# reference paper:Rosenzweig and Serafini, PhysRevE, Vol.49, p.1599,(1994)
f = (Ei / dE) * torch.log(1 + (dE / Ei))
Mbody = torch.tensor([
[1, self.length * f],
[0, Ei / Ef]
], device=device, dtype=dtype)

Mfent = torch.tensor([
[1, 0],
[-dE / (2 * self.length * Ei), 1]
], device=device, dtype=dtype)

Mfexit = torch.tensor([
[1, 0],
[dE / (2 * self.length * Ef), 1]
], device=device, dtype=dtype)
result = Mfexit @ Mbody @ Mfent

r11 = result[0, 0]
r12 = result[0, 1]
r21 = result[1, 0]
r22 = result[1, 1]

else:
raise ValueError(f"Unrecognized cavity type: '{self.cavity_type}'. Valid types are 'standing_wave' and 'Traveling_wave'.")

r56 = torch.tensor(0.0, **factory_kwargs)
beta0 = torch.tensor(1.0, **factory_kwargs)
Expand Down Expand Up @@ -345,13 +379,14 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No

@property
def defining_features(self) -> list[str]:
return super().defining_features + ["length", "voltage", "phase", "frequency"]
return super().defining_features + ["length", "voltage", "phase", "frequency", "cavity_type"]

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(length={repr(self.length)}, "
+ f"voltage={repr(self.voltage)}, "
+ f"phase={repr(self.phase)}, "
+ f"frequency={repr(self.frequency)}, "
+ f"cavity_type={repr(self.cavity_type)}, "
+ f"name={repr(self.name)})"
)
1 change: 1 addition & 0 deletions cheetah/converters/bmad.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def convert_element(
-np.degrees(bmad_parsed.get("phi0", 0.0) * 2 * np.pi)
),
frequency=torch.tensor(bmad_parsed["rf_frequency"]),
cavity_type=bmad_parsed["cavity_type"],
name=name,
device=device,
dtype=dtype,
Expand Down
Loading