Skip to content

Commit

Permalink
Merge branch '253-default-dtype' of https://github.com/Hespe/cheetah
Browse files Browse the repository at this point in the history
…into 253-default-dtype
  • Loading branch information
jank324 committed Nov 20, 2024
2 parents 8b7e654 + 329500e commit d80f4db
Show file tree
Hide file tree
Showing 14 changed files with 79 additions and 95 deletions.
7 changes: 3 additions & 4 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
Expand All @@ -25,8 +24,8 @@ class Aperture(Element):

def __init__(
self,
x_max: Optional[Union[torch.Tensor, nn.Parameter]] = None,
y_max: Optional[Union[torch.Tensor, nn.Parameter]] = None,
x_max: Optional[torch.Tensor] = None,
y_max: Optional[torch.Tensor] = None,
shape: Literal["rectangular", "elliptical"] = "rectangular",
is_active: bool = True,
name: Optional[str] = None,
Expand Down
11 changes: 5 additions & 6 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from scipy import constants
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam
Expand Down Expand Up @@ -34,10 +33,10 @@ class Cavity(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
voltage: Optional[Union[torch.Tensor, nn.Parameter]] = None,
phase: Optional[Union[torch.Tensor, nn.Parameter]] = None,
frequency: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor,
voltage: Optional[torch.Tensor] = None,
phase: Optional[torch.Tensor] = None,
frequency: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=None,
Expand Down
5 changes: 2 additions & 3 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam
Expand All @@ -19,7 +18,7 @@ class CustomTransferMap(Element):

def __init__(
self,
transfer_map: Union[torch.Tensor, nn.Parameter],
transfer_map: torch.Tensor,
length: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
Expand Down
49 changes: 24 additions & 25 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
Expand Down Expand Up @@ -47,16 +46,16 @@ class Dipole(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
angle: Optional[Union[torch.Tensor, nn.Parameter]] = None,
k1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
e1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
e2: Optional[Union[torch.Tensor, nn.Parameter]] = None,
tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None,
gap: Optional[Union[torch.Tensor, nn.Parameter]] = None,
gap_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None,
fringe_integral: Optional[Union[torch.Tensor, nn.Parameter]] = None,
fringe_integral_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor,
angle: Optional[torch.Tensor] = None,
k1: Optional[torch.Tensor] = None,
e1: Optional[torch.Tensor] = None,
e2: Optional[torch.Tensor] = None,
tilt: Optional[torch.Tensor] = None,
gap: Optional[torch.Tensor] = None,
gap_exit: Optional[torch.Tensor] = None,
fringe_integral: Optional[torch.Tensor] = None,
fringe_integral_exit: Optional[torch.Tensor] = None,
fringe_at: Literal["neither", "entrance", "exit", "both"] = "both",
fringe_type: Literal["linear_edge"] = "linear_edge",
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
Expand Down Expand Up @@ -268,15 +267,15 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:

def _bmadx_body(
self,
x: Union[torch.Tensor, nn.Parameter],
px: Union[torch.Tensor, nn.Parameter],
y: Union[torch.Tensor, nn.Parameter],
py: Union[torch.Tensor, nn.Parameter],
z: Union[torch.Tensor, nn.Parameter],
pz: Union[torch.Tensor, nn.Parameter],
p0c: Union[torch.Tensor, nn.Parameter],
x: torch.Tensor,
px: torch.Tensor,
y: torch.Tensor,
py: torch.Tensor,
z: torch.Tensor,
pz: torch.Tensor,
p0c: torch.Tensor,
mc2: float,
) -> list[Union[torch.Tensor, nn.Parameter]]:
) -> list[torch.Tensor]:
"""
Track particle coordinates through bend body.
Expand Down Expand Up @@ -363,11 +362,11 @@ def _bmadx_body(
def _bmadx_fringe_linear(
self,
location: Literal["entrance", "exit"],
x: Union[torch.Tensor, nn.Parameter],
px: Union[torch.Tensor, nn.Parameter],
y: Union[torch.Tensor, nn.Parameter],
py: Union[torch.Tensor, nn.Parameter],
) -> list[Union[torch.Tensor, nn.Parameter]]:
x: torch.Tensor,
px: torch.Tensor,
y: torch.Tensor,
py: torch.Tensor,
) -> list[torch.Tensor]:
"""
Tracks linear fringe.
Expand Down
5 changes: 2 additions & 3 deletions cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import torch
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
Expand All @@ -28,7 +27,7 @@ class Drift(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
length: torch.Tensor,
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
name: Optional[str] = None,
device=None,
Expand Down
7 changes: 3 additions & 4 deletions cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.utils import (
Expand All @@ -29,8 +28,8 @@ class HorizontalCorrector(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
angle: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor,
angle: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=None,
Expand Down
11 changes: 5 additions & 6 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
Expand Down Expand Up @@ -34,10 +33,10 @@ class Quadrupole(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
k1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None,
tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor,
k1: Optional[torch.Tensor] = None,
misalignment: Optional[torch.Tensor] = None,
tilt: Optional[torch.Tensor] = None,
num_steps: int = 1,
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
name: Optional[str] = None,
Expand Down
21 changes: 10 additions & 11 deletions cheetah/accelerator/rbend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional, Union
from typing import Optional

import torch
from torch import nn

from cheetah.accelerator.dipole import Dipole
from cheetah.utils import UniqueNameGenerator
Expand All @@ -28,15 +27,15 @@ class RBend(Dipole):

def __init__(
self,
length: Optional[Union[torch.Tensor, nn.Parameter]],
angle: Optional[Union[torch.Tensor, nn.Parameter]] = None,
k1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
e1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
e2: Optional[Union[torch.Tensor, nn.Parameter]] = None,
tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None,
fringe_integral: Optional[Union[torch.Tensor, nn.Parameter]] = None,
fringe_integral_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None,
gap: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: Optional[torch.Tensor],
angle: Optional[torch.Tensor] = None,
k1: Optional[torch.Tensor] = None,
e1: Optional[torch.Tensor] = None,
e2: Optional[torch.Tensor] = None,
tilt: Optional[torch.Tensor] = None,
fringe_integral: Optional[torch.Tensor] = None,
fringe_integral_exit: Optional[torch.Tensor] = None,
gap: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=None,
Expand Down
9 changes: 4 additions & 5 deletions cheetah/accelerator/screen.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from copy import deepcopy
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from torch import nn
from torch.distributions import MultivariateNormal

from cheetah.accelerator.element import Element
Expand Down Expand Up @@ -44,11 +43,11 @@ class Screen(Element):
def __init__(
self,
resolution: Optional[tuple[int, int]] = None,
pixel_size: Optional[Union[torch.Tensor, nn.Parameter]] = None,
pixel_size: Optional[torch.Tensor] = None,
binning: Optional[int] = None,
misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None,
misalignment: Optional[torch.Tensor] = None,
method: Literal["histogram", "kde"] = "histogram",
kde_bandwidth: Optional[Union[torch.Tensor, nn.Parameter]] = None,
kde_bandwidth: Optional[torch.Tensor] = None,
is_blocking: bool = False,
is_active: bool = False,
name: Optional[str] = None,
Expand Down
9 changes: 4 additions & 5 deletions cheetah/accelerator/solenoid.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.track_methods import misalignment_matrix
Expand Down Expand Up @@ -35,9 +34,9 @@ class Solenoid(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter] = None,
k: Optional[Union[torch.Tensor, nn.Parameter]] = None,
misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor = None,
k: Optional[torch.Tensor] = None,
misalignment: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=None,
Expand Down
13 changes: 5 additions & 8 deletions cheetah/accelerator/space_charge_kick.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import torch
from scipy.constants import elementary_charge, epsilon_0, speed_of_light
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
Expand Down Expand Up @@ -48,15 +47,13 @@ class SpaceChargeKick(Element):

def __init__(
self,
effect_length: Union[torch.Tensor, nn.Parameter],
effect_length: torch.Tensor,
num_grid_points_x: int = 32, # TODO: Simplify these to a single tuple?
num_grid_points_y: int = 32,
num_grid_points_tau: int = 32,
grid_extend_x: Union[
torch.Tensor, nn.Parameter
] = 3, # TODO: Simplify these to a single tensor?
grid_extend_y: Union[torch.Tensor, nn.Parameter] = 3,
grid_extend_tau: Union[torch.Tensor, nn.Parameter] = 3,
grid_extend_x: torch.Tensor = 3, # TODO: Simplify these to a single tensor?
grid_extend_y: torch.Tensor = 3,
grid_extend_tau: torch.Tensor = 3,
name: Optional[str] = None,
device=None,
dtype=None,
Expand Down
15 changes: 7 additions & 8 deletions cheetah/accelerator/transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants, speed_of_light
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
Expand Down Expand Up @@ -34,12 +33,12 @@ class TransverseDeflectingCavity(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
voltage: Optional[Union[torch.Tensor, nn.Parameter]] = None,
phase: Optional[Union[torch.Tensor, nn.Parameter]] = None,
frequency: Optional[Union[torch.Tensor, nn.Parameter]] = None,
misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None,
tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None,
length: torch.Tensor,
voltage: Optional[torch.Tensor] = None,
phase: Optional[torch.Tensor] = None,
frequency: Optional[torch.Tensor] = None,
misalignment: Optional[torch.Tensor] = None,
tilt: Optional[torch.Tensor] = None,
num_steps: int = 1,
tracking_method: Literal["bmadx"] = "bmadx",
name: Optional[str] = None,
Expand Down
5 changes: 2 additions & 3 deletions cheetah/accelerator/undulator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Optional, Union
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.utils import UniqueNameGenerator
Expand All @@ -28,7 +27,7 @@ class Undulator(Element):

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
length: torch.Tensor,
is_active: bool = False,
name: Optional[str] = None,
device=None,
Expand Down
Loading

0 comments on commit d80f4db

Please sign in to comment.