Skip to content

Commit

Permalink
Merge branch 'master' into distribution-plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 authored Dec 10, 2024
2 parents 0e311e9 + a784bb0 commit 21d279b
Show file tree
Hide file tree
Showing 19 changed files with 253 additions and 314 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
- Fix issue with Dipole hgap conversion in Bmad import (see #261) (@cr-xu)
- Fix plotting for segments that contain tensors with `require_grad=True` (see #288) (@hespe)
- Fix bug where `Element.length` could not be set as a `torch.nn.Parameter` (see #301) (@jank324, @hespe)
- Fix registration of `torch.nn.Parameter` at initialization for elements and beams (see #303) (@hespe)

### 🐆 Other

Expand Down
28 changes: 10 additions & 18 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,16 @@ def __init__(
) -> None:
device, dtype = verify_device_and_dtype([x_max, y_max], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

self.register_buffer(
"x_max",
(
torch.as_tensor(x_max, **factory_kwargs)
if x_max is not None
else torch.tensor(float("inf"), **factory_kwargs)
),
)
self.register_buffer(
"y_max",
(
torch.as_tensor(y_max, **factory_kwargs)
if y_max is not None
else torch.tensor(float("inf"), **factory_kwargs)
),
)
super().__init__(name=name, **factory_kwargs)

self.register_buffer("x_max", torch.tensor(float("inf"), **factory_kwargs))
self.register_buffer("y_max", torch.tensor(float("inf"), **factory_kwargs))

if x_max is not None:
self.x_max = torch.as_tensor(x_max, **factory_kwargs)
if y_max is not None:
self.y_max = torch.as_tensor(y_max, **factory_kwargs)

self.shape = shape
self.is_active = is_active

Expand Down
40 changes: 13 additions & 27 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,33 +45,19 @@ def __init__(
[length, voltage, phase, frequency], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

self.register_buffer("length", torch.as_tensor(length, **factory_kwargs))
self.register_buffer(
"voltage",
(
torch.as_tensor(voltage, **factory_kwargs)
if voltage is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"phase",
(
torch.as_tensor(phase, **factory_kwargs)
if phase is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"frequency",
(
torch.as_tensor(frequency, **factory_kwargs)
if frequency is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
super().__init__(name=name, **factory_kwargs)

self.register_buffer("voltage", torch.tensor(0.0, **factory_kwargs))
self.register_buffer("phase", torch.tensor(0.0, **factory_kwargs))
self.register_buffer("frequency", torch.tensor(0.0, **factory_kwargs))

self.length = torch.as_tensor(length, **factory_kwargs)
if voltage is not None:
self.voltage = torch.as_tensor(voltage, **factory_kwargs)
if phase is not None:
self.phase = torch.as_tensor(phase, **factory_kwargs)
if frequency is not None:
self.frequency = torch.as_tensor(frequency, **factory_kwargs)

@property
def is_active(self) -> bool:
Expand Down
19 changes: 7 additions & 12 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,18 @@ def __init__(
[predefined_transfer_map, length], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
super().__init__(name=name, **factory_kwargs)

assert isinstance(predefined_transfer_map, torch.Tensor)
assert predefined_transfer_map.shape[-2:] == (7, 7)

self.register_buffer(
"predefined_transfer_map",
torch.as_tensor(predefined_transfer_map, **factory_kwargs),
)
self.register_buffer(
"length",
(
torch.as_tensor(length, **factory_kwargs)
if length is not None
else torch.zeros(predefined_transfer_map.shape[:-2], **factory_kwargs)
),
self.register_buffer("predefined_transfer_map", None)

self.predefined_transfer_map = torch.as_tensor(
predefined_transfer_map, **factory_kwargs
)
if length is not None:
self.length = torch.as_tensor(length, **factory_kwargs)

@classmethod
def from_merging_elements(
Expand Down
109 changes: 36 additions & 73 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,81 +80,44 @@ def __init__(
dtype,
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

self.register_buffer("length", torch.as_tensor(length, **factory_kwargs))
self.register_buffer(
"angle",
(
torch.as_tensor(angle, **factory_kwargs)
if angle is not None
else torch.tensor(0.0, **factory_kwargs)
),
super().__init__(name=name, **factory_kwargs)

self.register_buffer("angle", torch.tensor(0.0, **factory_kwargs))
self.register_buffer("k1", torch.tensor(0.0, **factory_kwargs))
self.register_buffer("_e1", torch.tensor(0.0, **factory_kwargs))
self.register_buffer("_e2", torch.tensor(0.0, **factory_kwargs))
self.register_buffer("fringe_integral", torch.tensor(0.0, **factory_kwargs))
self.register_buffer("fringe_integral_exit", None)
self.register_buffer("gap", torch.tensor(0.0, **factory_kwargs))
self.register_buffer("gap_exit", None)
self.register_buffer("tilt", torch.tensor(0.0, **factory_kwargs))

self.length = torch.as_tensor(length, **factory_kwargs)
if angle is not None:
self.angle = torch.as_tensor(angle, **factory_kwargs)
if k1 is not None:
self.k1 = torch.as_tensor(k1, **factory_kwargs)
if dipole_e1 is not None:
self._e1 = torch.as_tensor(dipole_e1, **factory_kwargs)
if dipole_e2 is not None:
self._e2 = torch.as_tensor(dipole_e2, **factory_kwargs)
if fringe_integral is not None:
self.fringe_integral = torch.as_tensor(fringe_integral, **factory_kwargs)
self.fringe_integral_exit = (
torch.tensor(fringe_integral_exit, **factory_kwargs)
if fringe_integral_exit is not None
else self.fringe_integral
)
self.register_buffer(
"k1",
(
torch.as_tensor(k1, **factory_kwargs)
if k1 is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"_e1",
(
torch.as_tensor(dipole_e1, **factory_kwargs)
if dipole_e1 is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"_e2",
(
torch.as_tensor(dipole_e2, **factory_kwargs)
if dipole_e2 is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"fringe_integral",
(
torch.as_tensor(fringe_integral, **factory_kwargs)
if fringe_integral is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"fringe_integral_exit",
(
self.fringe_integral
if fringe_integral_exit is None
else torch.tensor(fringe_integral_exit, **factory_kwargs)
),
)
self.register_buffer(
"gap",
(
torch.as_tensor(gap, **factory_kwargs)
if gap is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"gap_exit",
(
torch.as_tensor(gap_exit, **factory_kwargs)
if gap_exit is not None
else 1.0 * self.gap
),
)
self.register_buffer(
"tilt",
(
torch.as_tensor(tilt, **factory_kwargs)
if tilt is not None
else torch.tensor(0.0, **factory_kwargs)
),
if gap is not None:
self.gap = torch.as_tensor(gap, **factory_kwargs)
self.gap_exit = (
torch.tensor(gap_exit, **factory_kwargs)
if gap_exit is not None
else self.gap
)
if tilt is not None:
self.tilt = torch.as_tensor(tilt, **factory_kwargs)

self.fringe_at = fringe_at
self.fringe_type = fringe_type
self.tracking_method = tracking_method
Expand Down
4 changes: 2 additions & 2 deletions cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def __init__(
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
super().__init__(name=name, **factory_kwargs)

self.register_buffer("length", torch.as_tensor(length, **factory_kwargs))
self.length = torch.as_tensor(length, **factory_kwargs)
self.tracking_method = tracking_method

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
Expand Down
4 changes: 2 additions & 2 deletions cheetah/accelerator/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class Element(ABC, nn.Module):
:param name: Unique identifier of the element.
"""

def __init__(self, name: Optional[str] = None) -> None:
def __init__(self, name: Optional[str] = None, device=None, dtype=None) -> None:
super().__init__()

self.name = name if name is not None else generate_unique_name()
self.register_buffer("length", torch.tensor(0.0))
self.register_buffer("length", torch.tensor(0.0, device=device, dtype=dtype))

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
r"""
Expand Down
18 changes: 7 additions & 11 deletions cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,13 @@ def __init__(
) -> None:
device, dtype = verify_device_and_dtype([length, angle], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

self.register_buffer("length", torch.as_tensor(length, **factory_kwargs))
self.register_buffer(
"angle",
(
torch.as_tensor(angle, **factory_kwargs)
if angle is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
super().__init__(name=name, **factory_kwargs)

self.register_buffer("angle", torch.tensor(0.0, **factory_kwargs))

self.length = torch.as_tensor(length, **factory_kwargs)
if angle is not None:
self.angle = torch.as_tensor(angle, **factory_kwargs)

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
device = self.length.device
Expand Down
41 changes: 14 additions & 27 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,33 +47,20 @@ def __init__(
[length, k1, misalignment, tilt], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

self.register_buffer("length", torch.as_tensor(length, **factory_kwargs))
self.register_buffer(
"k1",
(
torch.as_tensor(k1, **factory_kwargs)
if k1 is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"misalignment",
(
torch.as_tensor(misalignment, **factory_kwargs)
if misalignment is not None
else torch.zeros((*self.length.shape, 2), **factory_kwargs)
),
)
self.register_buffer(
"tilt",
(
torch.as_tensor(tilt, **factory_kwargs)
if tilt is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
super().__init__(name=name, **factory_kwargs)

self.register_buffer("k1", torch.tensor(0.0, **factory_kwargs))
self.register_buffer("misalignment", torch.tensor((0.0, 0.0), **factory_kwargs))
self.register_buffer("tilt", torch.tensor(0.0, **factory_kwargs))

self.length = torch.as_tensor(length, **factory_kwargs)
if k1 is not None:
self.k1 = torch.as_tensor(k1, **factory_kwargs)
if misalignment is not None:
self.misalignment = torch.as_tensor(misalignment, **factory_kwargs)
if tilt is not None:
self.tilt = torch.as_tensor(tilt, **factory_kwargs)

self.num_steps = num_steps
self.tracking_method = tracking_method

Expand Down
40 changes: 11 additions & 29 deletions cheetah/accelerator/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
[pixel_size, misalignment, kde_bandwidth], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
super().__init__(name=name, **factory_kwargs)

assert method in [
"histogram",
Expand All @@ -70,34 +70,16 @@ def __init__(
self.is_blocking = is_blocking
self.is_active = is_active

self.register_buffer(
"pixel_size",
(
torch.as_tensor(pixel_size, **factory_kwargs)
if pixel_size is not None
else torch.tensor((1e-3, 1e-3), **factory_kwargs)
),
)
self.register_buffer(
"misalignment",
(
torch.as_tensor(misalignment, **factory_kwargs)
if misalignment is not None
else torch.tensor((0.0, 0.0), **factory_kwargs)
),
)
self.register_buffer(
"length",
torch.zeros(self.misalignment.shape[:-1], **factory_kwargs),
)
self.register_buffer(
"kde_bandwidth",
(
torch.as_tensor(kde_bandwidth, **factory_kwargs)
if kde_bandwidth is not None
else torch.clone(self.pixel_size[0])
),
)
self.register_buffer("pixel_size", torch.tensor((1e-3, 1e-3), **factory_kwargs))
self.register_buffer("misalignment", torch.tensor((0.0, 0.0), **factory_kwargs))
self.register_buffer("kde_bandwidth", torch.clone(self.pixel_size[0]))

if pixel_size is not None:
self.pixel_size = torch.as_tensor(pixel_size, **factory_kwargs)
if misalignment is not None:
self.misalignment = torch.as_tensor(misalignment, **factory_kwargs)
if kde_bandwidth is not None:
self.kde_bandwidth = torch.as_tensor(kde_bandwidth, **factory_kwargs)

self.set_read_beam(None)
self.cached_reading = None
Expand Down
Loading

0 comments on commit 21d279b

Please sign in to comment.